def test_proper_torsion():
    # proper torsions have a variadic number of terms
    patterns = [
        ['[*:1]-[#6X3:2]=[#6X3:3]-[*:4]', [[99., 99., 99.]]],
        ['[*:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[99., 99., 99.]]],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[1., 2., 3.], [4., 5., 6.]]],
        ['[#35:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[7., 8., 9.], [1., 3., 5.], [4., 4., 4.]]],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#9:4]', [[7., 8., 9.]]],
    ]

    smirks = [x[0] for x in patterns]
    params = [x[1] for x in patterns]
    props = None

    pth = bonded.ProperTorsionHandler(smirks, params, props)
    mol = Chem.MolFromSmiles("FC(Br)=C(Br)F")
    torsion_params, torsion_idxs = pth.parameterize(mol)

    assert torsion_idxs.shape == (8, 4)
    assert torsion_params.shape == (8, 3)

    torsion_params_new, torsion_vjp_fn, torsion_idxs_new = jax.vjp(
        functools.partial(pth.partial_parameterize, mol=mol), pth.params, has_aux=True)

    np.testing.assert_array_equal(torsion_params_new, torsion_params)
    np.testing.assert_array_equal(torsion_idxs_new, torsion_idxs)

    torsion_param_adjoints = np.random.randn(*torsion_params.shape)
    ff_adjoints = torsion_vjp_fn(torsion_param_adjoints)[0]

    mask = np.argwhere(torsion_params > 90)
    assert np.all(ff_adjoints[mask] == 0.0)
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(), **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = (None,) * self._n_layers
    if rng is not None:
        rngs = random.split(rng, self._n_layers)

    def call_compute_residual(x, weights):
        res = self.compute_residual(x, weights=weights, state=state[0],
                                    rng=rngs[0], **kwargs)
        return res

    assert len(ct) == 2
    ct = (ct[0], ct[0], ct[1])

    stack_with_residual, vjpfun = jax.vjp(call_compute_residual, output, weights[0])
    reconstructed_x = self.subtract_top(
        stack_with_residual, weights=weights[-1], state=state[-1], rng=rngs[-1],
        **kwargs)

    x_ct, residual_weights_ct = vjpfun(ct)
    assert not jax.tree_util.tree_leaves(weights[-1])
    add_top_weights_ct = weights[-1]
    return reconstructed_x, (x_ct, [residual_weights_ct, add_top_weights_ct])
def parameterize(self, mol):
    """
    Parameters
    ----------
    mol: Chem.ROMol
        molecule to be parameterized.

    Returns
    -------
    tuple
        (parameters of shape [N, 2], vjp_fn)

    """
    param_idxs = generate_nonbonded_idxs(mol, self.smirks)

    def param_fn(p):
        params = p[param_idxs]
        sigmas = params[:, 0]
        epsilons = jnp.power(params[:, 1], 2)  # resolves a super annoying singularity
        return jnp.stack([sigmas, epsilons], axis=1)

    return jax.vjp(param_fn, self.params)
def parameterize(self, mol):
    """
    Parameterize the given molecule.

    Parameters
    ----------
    mol: Chem.ROMol
        rdkit molecule, should have hydrogens pre-added

    Returns
    -------
    tuple of ((Q, 2) np.int32, ((Q, 2) np.float64, vjp_fn: R^{Qx2} -> R^{Px2}))
        System bond indices, per-bond parameters, and the vjp_fn.

    """
    bond_idxs, param_idxs = generate_vd_idxs(mol, self.smirks)
    param_fn = functools.partial(parameterize_ligand, param_idxs=param_idxs)
    sys_params, vjp_fn = jax.vjp(param_fn, self.params)
    return np.array(bond_idxs, dtype=np.int32), (np.array(sys_params, dtype=np.float64), vjp_fn)
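# A minimal usage sketch of the parameterize() API above. The handler instance and
# the SMILES string are illustrative assumptions, not taken from the original code.
mol = Chem.AddHs(Chem.MolFromSmiles("CCO"))
bond_idxs, (sys_params, vjp_fn) = handler.parameterize(mol)
# Pull adjoints on the per-bond parameters (shape (Q, 2)) back onto the
# handler's SMIRKS parameters (shape (P, 2)).
(param_adjoints,) = vjp_fn(np.ones_like(sys_params))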
def critic_loss(
    critic_params: hk.Params,
    generator_params: hk.Params,
    img_batch: ImgBatch,
    latent_batch: LatentBatch,
) -> Tuple[jnp.ndarray, Log]:
    batch_size = img_batch.shape[0]
    img_generated = generator.apply(generator_params, latent_batch)

    def f(x):
        return critic.apply(critic_params, x)

    f_real = f(img_batch)
    # Use the vector-jacobian product to efficiently compute the grad for each f_i.
    f_gen, grad_fn = jax.vjp(f, img_generated)
    grad = grad_fn(jnp.ones(f_gen.shape))[0]

    assert_shape(f_real, (batch_size, 1))
    assert_shape(f_gen, (batch_size, 1))
    assert_shape(grad, img_batch.shape)

    flat_grad = grad.reshape(batch_size, -1)
    gp = jnp.square(1 - jnp.linalg.norm(flat_grad, axis=1))
    assert_shape(gp, (batch_size,))

    loss = jnp.mean(f_gen) - jnp.mean(f_real)
    gradient_penalty = jnp.mean(gp)
    log = {
        "wasserstein": -loss,
        "gradient_penalty": gradient_penalty,
    }
    return loss + 10 * gradient_penalty, log
def f(x):
    y = x + 1
    p, vjp_f = vjp(lambda z: jnp.sin(with_sharding_constraint(z, P(2, 2))), y)
    return vjp_f(p)
def _helper(cond):
    outp, _ = jax.vjp(lambda x: _bijector(root, x)[0], cond)
    _, vjp_fun = jax.vjp(lambda x: _bijector(x, cond)[0], root)
    jac = vjp_fun(jnp.ones_like(outp))[0]
    return (outp, outp, jac)
def _call(x, *params):
    y, vjp_bijector = jax.vjp(lambda x: bijector(x, *params), x)
    ldj = jnp.log(vjp_bijector(jnp.ones_like(y))[0])
    return y, ldj
def uoro_grad(key, theta, x, t0, T, K, s_tilde=None, theta_tilde=None):
    epsilon_perturbation = 1e-7
    epsilon_stability = 1e-7

    t_current = t0
    mystate = x
    total_theta_grad = 0
    total_loss = 0.0

    if s_tilde is None:
        s_tilde = jnp.zeros(mystate.shape)
    if theta_tilde is None:
        theta_tilde = jnp.zeros(theta.shape)

    for i in range(K):
        state_vec_old = mystate
        state_new = single_step(theta, mystate, t_current, T)
        total_loss += loss(state_new) * (t_current < T)
        state_vec_new = state_new

        dl_dstate_old = compute_dL_dstate_old(theta, state_vec_old, t_current, T)
        dl_dtheta_direct = compute_dL_dtheta_direct(theta, state_vec_old, t_current, T)

        indirect_grad = (dl_dstate_old * s_tilde).sum() * theta_tilde
        pseudograds = indirect_grad + dl_dtheta_direct

        state_old_perturbed = state_vec_old + s_tilde * epsilon_perturbation
        state_vec_new_perturbed = single_step(theta, state_old_perturbed, t_current, T)
        state_deriv_in_direction_s_tilde = (
            state_vec_new_perturbed - state_vec_new) / epsilon_perturbation

        key, skey = jax.random.split(key)
        nus = jnp.round(jax.random.uniform(skey, state_vec_old.shape)) * 2 - 1

        custom_f = lambda param_vector: single_step(
            param_vector, state_vec_old, t_current, T)
        primals, f_vjp = jax.vjp(custom_f, theta)
        direct_theta_tilde_contribution, = f_vjp(nus)

        rho_0 = jnp.sqrt(
            (jnp.linalg.norm(theta_tilde) + epsilon_stability) /
            (jnp.linalg.norm(state_deriv_in_direction_s_tilde) + epsilon_stability))
        rho_1 = jnp.sqrt(
            (jnp.linalg.norm(direct_theta_tilde_contribution) + epsilon_stability) /
            (jnp.linalg.norm(nus) + epsilon_stability))

        theta_grad = pseudograds
        total_theta_grad += theta_grad

        s_tilde = rho_0 * state_deriv_in_direction_s_tilde + rho_1 * nus
        theta_tilde = theta_tilde / rho_0 + direct_theta_tilde_contribution / rho_1

        mystate = state_new
        t_current += 1

    return (key, total_loss, mystate, s_tilde, theta_tilde), total_theta_grad
def body_fun(vals):
    """Performs attention for a single batch element and head."""
    batch_loop_idx = vals[0]
    if self._prng is None:
        hash_slice_rng = jax.random.fold_in(rng, batch_loop_idx)
        hash_rng, slice_rng = backend.random.split(hash_slice_rng)
    else:
        # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
        hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
        slice_rng = jax.random.fold_in(rng, batch_loop_idx)
    qk_slice = jax.lax.dynamic_index_in_dim(qk, batch_loop_idx, axis=0, keepdims=False)
    v_slice = jax.lax.dynamic_index_in_dim(v, batch_loop_idx, axis=0, keepdims=False)

    if buckets is None:
        buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
    else:
        buckets_slice = jax.lax.dynamic_index_in_dim(
            buckets, batch_loop_idx, axis=0, keepdims=False)

    if ct is None:
        out_slice = self.single_call(qk_slice, v_slice, buckets_slice, rng=slice_rng)
    else:
        def _do_single_call(qk_slice, v_slice):
            return self.single_call(qk_slice, v_slice, buckets_slice, rng=slice_rng)
        ct_slice = jax.lax.dynamic_index_in_dim(ct, batch_loop_idx, axis=0, keepdims=False)
        out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
        qk_ct_slice, v_ct_slice = vjpfun(ct_slice)

    new_vals = (batch_loop_idx + 1,)
    if return_output:
        out_accum = vals[1]
        out_accum = jax.lax.dynamic_update_index_in_dim(out_accum, out_slice,
                                                        batch_loop_idx, axis=0)
        new_vals = new_vals + (out_accum,)
    if return_state:
        buckets_accum = vals[2]
        buckets_accum = jax.lax.dynamic_update_index_in_dim(
            buckets_accum, buckets_slice, batch_loop_idx, axis=0)
        new_vals = new_vals + (buckets_accum,)
    if ct is not None:
        qk_ct_accum, v_ct_accum = vals[-2:]
        qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
            qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
        v_ct_accum = jax.lax.dynamic_update_index_in_dim(
            v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
        new_vals = new_vals + (qk_ct_accum, v_ct_accum)
    return new_vals
def binned_attn_vjp(sq, sk, sv, so_ct):
    so, vjpfun = jax.vjp(binned_attn, sq, sk, sv)
    sqkv_ct = vjpfun(so_ct)
    return so, sqkv_ct
def parameterize(self, mol):
    """
    Parameters
    ----------
    mol: Chem.ROMol
        molecule to be parameterized.

    """
    # imported here for optional dependency
    from openeye import oechem
    from openeye import oequacpac

    mb = Chem.MolToMolBlock(mol)
    ims = oechem.oemolistream()
    ims.SetFormat(oechem.OEFormat_SDF)
    ims.openstring(mb)

    for buf_mol in ims.GetOEMols():
        oemol = oechem.OEMol(buf_mol)
        # AromaticityModel.assign(oe_molecule, bcc_collection.aromaticity_model)
        AromaticityModel.assign(oemol)

    # check for cache
    cache_key = 'AM1Cache'
    if not mol.HasProp(cache_key):
        result = oequacpac.OEAssignCharges(
            oemol, oequacpac.OEAM1Charges(symmetrize=True))
        if result is False:
            raise Exception('Unable to assign charges')

        am1_charges = []
        for index, atom in enumerate(oemol.GetAtoms()):
            q = atom.GetPartialCharge() * np.sqrt(constants.ONE_4PI_EPS0)
            am1_charges.append(q)

        mol.SetProp(cache_key, base64.b64encode(pickle.dumps(am1_charges)))
    else:
        am1_charges = pickle.loads(base64.b64decode(mol.GetProp(cache_key)))

    bond_idxs = []
    bond_idx_params = []

    for index in range(len(self.smirks)):
        smirk = self.smirks[index]
        param = self.params[index]

        substructure_search = oechem.OESubSearch(smirk)
        substructure_search.SetMaxMatches(0)

        matched_bonds = []
        matches = []
        for match in substructure_search.Match(oemol):
            matched_indices = {
                atom_match.pattern.GetMapIdx() - 1: atom_match.target.GetIdx()
                for atom_match in match.GetAtoms()
                if atom_match.pattern.GetMapIdx() != 0
            }
            matches.append(matched_indices)

        for matched_indices in matches:
            forward_matched_bond = [matched_indices[0], matched_indices[1]]
            reverse_matched_bond = [matched_indices[1], matched_indices[0]]

            if (forward_matched_bond in matched_bonds
                    or reverse_matched_bond in matched_bonds
                    or forward_matched_bond in bond_idxs
                    or reverse_matched_bond in bond_idxs):
                continue

            matched_bonds.append(forward_matched_bond)
            bond_idxs.append(forward_matched_bond)
            bond_idx_params.append(index)

    bcc_fn = functools.partial(
        apply_bcc,
        bond_idxs=np.array(bond_idxs),
        bond_idx_params=np.array(bond_idx_params, dtype=np.int32),
        am1_charges=np.array(am1_charges))

    charges, vjp_fn = jax.vjp(bcc_fn, self.params)

    return np.array(charges, dtype=np.float64), vjp_fn
def cost_func_vjp(bb, u):
    _, jax_vjp = jax.vjp(cost_func, bb, u)
    directmat = jax_vjp(1.0)
    return directmat[0]
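# A hedged sketch of how cost_func_vjp behaves, assuming it resolves a module-level,
# scalar-valued cost_func. The quadratic below is a stand-in for the original cost_func,
# not the real one: pulling back the cotangent 1.0 simply gives d cost / d bb.
def cost_func(bb, u):
    return jnp.sum(bb ** 2) + jnp.sum(bb * u)

bb = jnp.arange(3.0)
u = jnp.ones(3)
dcost_dbb = cost_func_vjp(bb, u)   # equals 2 * bb + u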
# Apply network to dummy inputs
inputs = np.zeros((1, 32))
# predictions = net_apply(net_params, inputs)
# print ("pred: ", predictions)

def net_apply_reverse(inputs, net_params):
    return net_apply(net_params, inputs)

@jit
def test_loss(net_params, inputs):
    return np.sum(net_apply(net_params, inputs))

primals_out, vjpfun = vjp(partial(net_apply_reverse, inputs), net_params)
print(primals_out)
primals_out, jvpfun = linearize(partial(net_apply_reverse, inputs), net_params)
# primals_out, vp = jvp(net_apply, (net_params, inputs), random.normal(rng, (1, 256)))
print(primals_out)
input("")

for i in range(10):
    import time
    s = time.time()
    out = vjpfun(random.normal(rng, (1, 10)))
    e = time.time()
    print("vjp time: ", (e - s))
    s = time.time()
def rmatvec(v):
    _, vjp_fn = vjp(self.c, z)
    return vjp_fn(v)[0]
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(), **kwargs):
    rngs = _pop_rng_and_split(kwargs, len(self.sublayers))

    accumulator_output, *context = output
    context = tuple(context)
    accumulator_output_ct, *context_ct = ct
    context_ct = tuple(context_ct)

    # Forward pass through self.compute_residual. Outputs that will not receive
    # a gradient signal from subsequent layers are moved to aux.
    def call_compute_residual(x, weights):
        res, _ = self.compute_residual.pure_fn(
            x, weights=weights, state=state[0], rng=rngs[0])
        if not isinstance(res, (tuple, list)):
            return res, None
        else:
            n_differentiable = 1
            if self.attention_layer is not None:
                n_differentiable = min(len(res), self.attention_layer.n_in)
            return res[:n_differentiable], res[n_differentiable:]

    stack = context
    inputs = _inputs_from_stack(self.compute_residual, stack)
    outputs, compute_residual_vjpfun, outputs_aux = jax.vjp(
        call_compute_residual, inputs, weights[0], has_aux=True)
    if outputs_aux is not None:
        n_differentiable_outputs = len(outputs)
        outputs = outputs + outputs_aux
    stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

    stack_ct = accumulator_output_ct
    if self.attention_layer is None:
        residual = stack[0] if isinstance(stack, (tuple, list)) else stack
    else:
        inputs = _inputs_from_stack(self.attention_layer, stack)
        (residual, _, attn_inputs_ct, attn_weights_ct
        ) = self.attention_layer.forward_and_or_backward(
            inputs, weights[1], new_state[1], rngs[1],
            output_grad=accumulator_output_ct,
            compute_output=True, update_state=False)
        stack_ct = _outputs_onto_stack(
            self.attention_layer, attn_inputs_ct, stack_ct,
            self.attention_layer.n_out, self.attention_layer.n_in)

    compute_residual_ct = _inputs_from_stack(
        self.compute_residual, stack_ct, self.compute_residual.n_out)
    if outputs_aux is not None:
        if not isinstance(compute_residual_ct, (tuple, list)):
            compute_residual_ct = (compute_residual_ct,)
        compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
        assert len(compute_residual_ct) == n_differentiable_outputs
    (compute_residual_inputs_ct, compute_residual_weights_ct
    ) = compute_residual_vjpfun(compute_residual_ct)
    stack_ct = _outputs_onto_stack(
        self.compute_residual, compute_residual_inputs_ct, stack_ct,
        self.compute_residual.n_out, self.compute_residual.n_in)
    if not isinstance(stack_ct, (tuple, list)):
        stack_ct = (stack_ct,)
    stack_ct = (accumulator_output_ct,) + jax.tree_multimap(
        lambda x, y: x + y, context_ct[:len(stack_ct)], stack_ct
    ) + context_ct[len(stack_ct):]

    reconstructed_x = accumulator_output - residual
    stack = (reconstructed_x,) + context
    if self.attention_layer is None:
        weights_ct = (compute_residual_weights_ct,)
    else:
        weights_ct = (compute_residual_weights_ct, attn_weights_ct)
    return stack, (stack_ct, weights_ct)
def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):
    output_slice, vjpfun = jax.vjp(forward_slice, query_slice, q_loop_idx, key, value)
    return output_slice, vjpfun(ct_slice)
def f_vjp(args, cts):
    res, pullback = jax.vjp(f, *args)
    return pullback(cts)
def reverse_and_grad(self, output, ct, params=(), **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = (None,) * self._n_layers
    if rng is not None:
        rngs = backend.random.split(rng, self._n_layers)

    # Forward pass through self.pre_attention, while preparing for
    # later backprop.
    # Note: jax.vjp does not allow us to use **kwargs in the signature here.
    def call_pre_attention(x, params, kwargs):
        return self.pre_attention(x, params, **kwargs)
    pre_attention_kwargs = kwargs.copy()
    pre_attention_kwargs['rng'] = rngs[0]
    stack, pre_attention_vjpfun = jax.vjp(call_pre_attention, output, params[0],
                                          pre_attention_kwargs)

    # Backprop through adding the residual
    assert len(ct) == 2
    ct = saved_ct = (ct[0], ct[0], ct[1])

    # Backprop through self.post_attention with respect to the inputs only
    call_post_attention_kwargs = kwargs.copy()
    call_post_attention_kwargs['rng'] = rngs[2]
    def call_post_attention(x):
        return self.post_attention(x, params[2], **call_post_attention_kwargs)
    # Note: these are *not* the actual inputs to self.post_attention.
    # If self.post_attention is not linear, we will get incorrect gradients.
    dummy_inputs = (stack[-3], stack[-2], stack[-1])
    _, post_attention_vjpfun = jax.vjp(call_post_attention, dummy_inputs)
    (ct,) = post_attention_vjpfun(ct)

    # Simultaneous forward pass and backprop through the attention mechanism
    attention_kwargs = kwargs.copy()
    attention_kwargs['rng'] = rngs[1]
    stack, ct = self.attention.forward_and_vjp(stack, ct, **attention_kwargs)
    attention_params_ct = ()

    # Backprop through self.pre_attention
    (x_ct, pre_attention_params_ct,
     pre_attention_kwargs_ct) = pre_attention_vjpfun(ct)

    # Forward pass for self.post_attention, and backprop with respect to the
    # parameters only
    def call_post_attention2(params, kwargs):
        return self.post_attention(stack, params, **kwargs)
    stack, post_attention_vjpfun = jax.vjp(call_post_attention2, params[2],
                                           call_post_attention_kwargs)
    (post_attention_params_ct,
     post_attention_kwargs_ct) = post_attention_vjpfun(saved_ct)

    # Forward pass through subtracting the residual
    reconstructed_x = self.subtract_top(stack, params[-1], rng=rngs[-1], **kwargs)

    params_ct = (
        pre_attention_params_ct,
        attention_params_ct,
        post_attention_params_ct,
        (),
    )

    # We don't actually backprop through the kwargs, but the API requires that
    # we provide a value for kwargs_ct.
    kwargs_ct = pre_attention_kwargs_ct
    del post_attention_kwargs_ct

    return reconstructed_x, (x_ct, params_ct, kwargs_ct)
import jax.numpy as jnp
from jax import jit, vjp, jacfwd

import simsoptpp as sopp
from .._core.optimizable import Optimizable


@jit
def incremental_arclength_pure(d1gamma):
    """
    This function is used in a Python+Jax implementation of the curve arc length formula.
    """
    return jnp.linalg.norm(d1gamma, axis=1)


incremental_arclength_vjp = jit(
    lambda d1gamma, v: vjp(lambda d1g: incremental_arclength_pure(d1g), d1gamma)[1](v)[0])


@jit
def kappa_pure(d1gamma, d2gamma):
    """
    This function is used in a Python+Jax implementation of the formula for curvature.
    """
    return jnp.linalg.norm(jnp.cross(d1gamma, d2gamma), axis=1) \
        / jnp.linalg.norm(d1gamma, axis=1)**3


kappavjp0 = jit(lambda d1gamma, d2gamma, v:
                vjp(lambda d1g: kappa_pure(d1g, d2gamma), d1gamma)[1](v)[0])
kappavjp1 = jit(lambda d1gamma, d2gamma, v:
                vjp(lambda d2g: kappa_pure(d1gamma, d2g), d2gamma)[1](v)[0])
kappagrad0 = jit(lambda d1gamma, d2gamma:
                 jacfwd(lambda d1g: kappa_pure(d1g, d2gamma))(d1gamma))
kappagrad1 = jit(lambda d1gamma, d2gamma:
                 jacfwd(lambda d2g: kappa_pure(d1gamma, d2g))(d2gamma))
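# A minimal usage sketch for the vjp helpers above, assuming a dummy discretized
# curve; the array shapes here are illustrative and not part of the original module.
d1gamma = jnp.ones((100, 3))          # dgamma/dphi at 100 quadrature points
v = jnp.ones(100)                     # cotangent, one entry per quadrature point
dl = incremental_arclength_pure(d1gamma)                # per-point incremental arc length
grad_d1gamma = incremental_arclength_vjp(d1gamma, v)    # pulled-back gradient, shape (100, 3)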
def value_and_grad(fn, args):
    """Given `fn: (args) -> out, extra`, returns `dout/dargs`."""
    output, vjp_fn, extra = jax.vjp(fn, args, has_aux=True)
    grad = vjp_fn(np.ones_like(output))[0]
    return output, extra, grad
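# A hedged usage sketch for value_and_grad above; the quadratic fn is a made-up
# example whose auxiliary output is a simple dict.
def fn(args):
    out = jnp.sum(args ** 2)
    extra = {"norm": jnp.linalg.norm(args)}
    return out, extra

output, extra, grad = value_and_grad(fn, jnp.arange(3.0))   # grad == 2 * args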
"x": 2.2, "y": [1.0, 2.2] }], [], [{ "x": 3.3, "y": [1.0, 2.0, 3.0] }]]) tangent = ak.Array([[{ "x": 0.0, "y": [1.0] }, { "x": 2.0, "y": [1.5, 0.0] }], [], [{ "x": 1.5, "y": [2.0, 0.5, 1.0] }]]) primal_result, tangent_result = jax.jvp(func, (primal, ), (tangent, )) print("resulting types", type(primal_result), type(tangent_result)) print(primal_result) print(tangent_result) jit_result = jax.jit(func)(primal) print("resulting type", type(jit_result)) print(jit_result) val, func = jax.vjp(func, primal) print(func(val))
def reverse_and_grad(self, output, ct, params=(), state=(), **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = (None,) * self._n_layers
    if rng is not None:
        rngs = backend.random.split(rng, self._n_layers)

    # Forward pass through self.pre_attention, while preparing for
    # later backprop.
    def call_pre_attention(x, params):
        res = self.pre_attention(x, params=params, state=state[0], rng=rngs[0],
                                 **kwargs)
        return res
    stack, pre_attention_vjpfun = jax.vjp(call_pre_attention, output, params[0])

    # Backprop through adding the residual
    assert len(ct) == 2
    ct = saved_ct = (ct[0], ct[0], ct[1])

    # Backprop through self.post_attention with respect to the inputs only
    def call_post_attention(x):
        res = self.post_attention(x, params=params[2], state=state[2], rng=rngs[2],
                                  **kwargs)
        return res
    # Note: these are *not* the actual inputs to self.post_attention.
    # If self.post_attention is not linear, we will get incorrect gradients.
    dummy_inputs = (stack[-3], stack[-2], stack[-1])
    _, post_attention_vjpfun = jax.vjp(call_post_attention, dummy_inputs)
    (ct,) = post_attention_vjpfun(ct)

    # Simultaneous forward pass and backprop through the attention mechanism
    stack, ct = self.attention.forward_and_backward(stack, ct, rng=rngs[1],
                                                    **kwargs)
    assert not jax.tree_util.tree_leaves(params[1])
    attention_params_ct = params[1]  # This is valid when params is empty.

    # Backprop through self.pre_attention
    x_ct, pre_attention_params_ct = pre_attention_vjpfun(ct)

    # Forward pass for self.post_attention, and backprop with respect to the
    # parameters only
    def call_post_attention2(params):
        res = self.post_attention(stack, params=params, state=state[2], rng=rngs[2],
                                  **kwargs)
        return res
    stack, post_attention_vjpfun = jax.vjp(call_post_attention2, params[2])
    (post_attention_params_ct,) = post_attention_vjpfun(saved_ct)

    # Forward pass through subtracting the residual
    reconstructed_x = self.subtract_top(stack, params=params[-1], state=state[-1],
                                        rng=rngs[-1], **kwargs)

    assert not jax.tree_util.tree_leaves(params[-1])
    add_top_params_ct = params[-1]

    params_ct = [
        pre_attention_params_ct,
        attention_params_ct,
        post_attention_params_ct,
        add_top_params_ct,
    ]

    return reconstructed_x, (x_ct, params_ct)
def augmented_dynamics(augmented_state, t, flat_args):
    # Original system augmented with vjp_y, vjp_t and vjp_args.
    y, adjoint, _, _ = unpack(augmented_state)
    dy_dt, vjp_all = vjp(flat_func, y, t, flat_args)
    vjp_a, vjp_t, vjp_args = vjp_all(-adjoint)
    return np.concatenate([dy_dt, vjp_a, vjp_t.reshape(1), vjp_args])
def _jac_diag(inp):
    outp, vjp_fun = jax.vjp(lambda x: _bijector(x, cond)[0], inp)
    return vjp_fun(jnp.ones_like(outp))[0]
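# An illustration of the trick used in _jac_diag: for an elementwise map, pulling back
# a ones cotangent yields the diagonal of the Jacobian. jnp.tanh stands in for the
# original _bijector here and is purely an assumption for this sketch.
x = jnp.array([0.1, 0.5, 2.0])
y, vjp_fun = jax.vjp(jnp.tanh, x)
jac_diag = vjp_fun(jnp.ones_like(y))[0]   # equals 1 - jnp.tanh(x) ** 2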
def augmented_dynamics(augmented_state, t):
    y, adjoint = unpack(augmented_state)
    dy_dt, vjp_all = vjp(func, y, t)
    vjp_a, vjp_t = vjp_all(adjoint)
    return np.concatenate([dy_dt, vjp_a])
def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
    so, vjpfun = jax.vjp(binned_attn, sqk, sv)
    sqkv_ct = vjpfun(so_ct)
    return so, sqkv_ct
def _test_transformation(self, func, param, msg=None):
    out, f_vjp = jax.vjp(func, param)
    cotangent, = f_vjp(np.ones_like(out).astype(out.dtype))
    self.assertEqual(param.shape, cotangent.shape)
    if not FLAGS.execute_only:
        self.assertNotAllEqual(cotangent, np.zeros_like(cotangent), msg=msg)
def expected_f(x):
    y = x + 1
    p, vjp_f = vjp(lambda z: jnp.sin(z), y)
    return vjp_f(p)
def full_vjp_func(func_args):
    # Trace the tagged function
    typed_jaxpr = jax.make_jaxpr(tagged_func)(*func_args)
    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
    layer_tags, loss_tags = extract_tags(jaxpr)

    layer_vars_flat = jax.tree_flatten([tag.invars for tag in layer_tags])[0]
    layer_input_vars = tuple(set(layer_vars_flat))

    def forward():
        own_func_args = func_args

        # Mapping from variable -> value
        env = dict()
        read = functools.partial(tgm.read_env, env)
        write = functools.partial(tgm.write_env, env)

        # Bind args and consts to environment
        write(jax.core.unitvar, jax.core.unit)
        jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
        jax_util.safe_map(write, jaxpr.constvars, consts)

        # Loop through equations and evaluate primitives using `bind`
        num_losses_passed = 0
        for eqn in jaxpr.eqns:
            tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
            if isinstance(eqn.primitive, tags.LossTag):
                num_losses_passed += 1
                if num_losses_passed == len(loss_tags):
                    break
        if num_losses_passed != len(loss_tags):
            raise ValueError("This should be unreachable.")

        return jax_util.safe_map(read, layer_input_vars)

    def forward_aux(aux):
        own_func_args = func_args

        # Mapping from variable -> value
        env = dict()
        read = functools.partial(tgm.read_env, env)

        def write(var, val):
            if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)):
                val = val + aux[var] if var in aux else val
            env[var] = val

        # Bind args and consts to environment
        write(jax.core.unitvar, jax.core.unit)
        jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
        jax_util.safe_map(write, jaxpr.constvars, consts)

        # Loop through equations and evaluate primitives using `bind`
        num_losses_passed = 0
        losses_inputs_values = []
        losses_kwargs_values = []
        for eqn in jaxpr.eqns:
            input_values = jax_util.safe_map(read, eqn.invars)
            tgm.evaluate_eqn(eqn, input_values, write)
            if isinstance(eqn.primitive, tags.LossTag):
                loss = eqn.primitive.loss(*input_values, weight=eqn.params["weight"])
                losses_inputs_values.append(loss.inputs)
                losses_kwargs_values.append(
                    dict(targets=loss.targets, weight=eqn.params["weight"]))
                num_losses_passed += 1
                if num_losses_passed == len(loss_tags):
                    break
        if num_losses_passed != len(loss_tags):
            raise ValueError("This should be unreachable.")

        # Read the inputs to the loss functions, but also return the target values
        return tuple(losses_inputs_values), tuple(losses_kwargs_values)

    layer_input_values = forward()
    primals_dict = dict(zip(layer_input_vars, layer_input_values))
    primals_dict.update(zip(jaxpr.invars, jax.tree_flatten(func_args)[0]))
    aux_values = jax.tree_map(jnp.zeros_like, layer_input_values)
    aux_dict = dict(zip(layer_input_vars, aux_values))

    losses_args, aux_vjp, losses_kwargs = jax.vjp(forward_aux, aux_dict, has_aux=True)
    losses = tuple(
        tag.primitive.loss(*inputs, **kwargs)
        for tag, inputs, kwargs in zip(loss_tags, losses_args, losses_kwargs))

    def vjp_func(tangents):
        all_tangents = aux_vjp(tangents)
        tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]
        inputs_tangents = jax.tree_flatten(inputs_tangents)[0]
        tangents_dict.update(zip(jaxpr.invars, inputs_tangents))

        read_primals = functools.partial(tgm.read_env, primals_dict)
        read_tangents = functools.partial(tgm.read_env, tangents_dict)
        layers_info = []
        for jaxpr_eqn in layer_tags:
            layer_tag = _unbox_layer_tag(jaxpr_eqn)
            info = dict()
            primals = jax_util.safe_map(read_primals, tuple(jaxpr_eqn.invars))
            (
                info["outputs"],
                info["inputs"],
                info["params"],
            ) = layer_tag.split_all_inputs(primals)
            tangents = jax_util.safe_map(read_tangents, tuple(jaxpr_eqn.invars))
            (
                info["outputs_tangent"],
                info["inputs_tangent"],
                info["params_tangent"],
            ) = layer_tag.split_all_inputs(tangents)
            layers_info.append(info)
        return tuple(layers_info)

    return losses, vjp_func