def test_proper_torsion():

    # proper torsions have a variadic number of terms

    patterns = [
        ['[*:1]-[#6X3:2]=[#6X3:3]-[*:4]', [[99., 99., 99.]]],
        ['[*:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[99., 99., 99.]]],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[1., 2., 3.], [4., 5., 6.]]],
        ['[#35:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[7., 8., 9.], [1., 3., 5.], [4., 4., 4.]]],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#9:4]', [[7., 8., 9.]]],
    ]

    smirks = [x[0] for x in patterns]
    params = [x[1] for x in patterns]
    props = None

    pth = bonded.ProperTorsionHandler(smirks, params, props)

    mol = Chem.MolFromSmiles("FC(Br)=C(Br)F")

    torsion_params, torsion_idxs = pth.parameterize(mol)

    assert torsion_idxs.shape == (8, 4)
    assert torsion_params.shape == (8, 3)

    torsion_params_new, torsion_vjp_fn, torsion_idxs_new = jax.vjp(functools.partial(pth.partial_parameterize, mol=mol), pth.params, has_aux=True)

    np.testing.assert_array_equal(torsion_params_new, torsion_params)
    np.testing.assert_array_equal(torsion_idxs_new, torsion_idxs)

    torsion_param_adjoints = np.random.randn(*torsion_params.shape)

    ff_adjoints = torsion_vjp_fn(torsion_param_adjoints)[0]

    mask = np.argwhere(torsion_params > 90)
    assert np.all(ff_adjoints[mask] == 0.0)
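
The test above relies on a common jax.vjp pattern: functools.partial pins the non-differentiable molecule argument, has_aux=True carries the integer index array alongside the differentiable parameters, and the returned vjp function pulls output adjoints back onto the force-field parameters. A minimal self-contained sketch of that pattern (my own illustration, not part of the test):

import functools

import jax
import jax.numpy as jnp

def assign(params, idxs):
    # differentiable gather plus the integer assignments as auxiliary output
    return params[idxs], idxs

params = jnp.arange(5.0)
idxs = jnp.array([0, 2, 2, 4])

out, vjp_fn, aux = jax.vjp(functools.partial(assign, idxs=idxs), params, has_aux=True)
(param_adjoints,) = vjp_fn(jnp.ones_like(out))  # scatter-adds the output adjoints back onto params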
Example #2
    def reverse_and_grad(self,
                         output,
                         ct,
                         weights=(),
                         state=(),
                         new_state=(),
                         **kwargs):
        rng = kwargs.pop('rng', None)
        rngs = (None, ) * self._n_layers
        if rng is not None:
            rngs = random.split(rng, self._n_layers)

        def call_compute_residual(x, weights):
            res = self.compute_residual(x,
                                        weights=weights,
                                        state=state[0],
                                        rng=rngs[0],
                                        **kwargs)
            return res

        assert len(ct) == 2
        ct = ((ct[0], ct[0], ct[1]))

        stack_with_residual, vjpfun = jax.vjp(call_compute_residual, output,
                                              weights[0])
        reconstructed_x = self.subtract_top(stack_with_residual,
                                            weights=weights[-1],
                                            state=state[-1],
                                            rng=rngs[-1],
                                            **kwargs)

        x_ct, residual_weights_ct = vjpfun(ct)
        assert not jax.tree_util.tree_leaves(weights[-1])
        add_top_weights_ct = weights[-1]
        return reconstructed_x, (x_ct,
                                 [residual_weights_ct, add_top_weights_ct])
Example #3
    def parameterize(self, mol):
        """
        Parameters
        ----------

        mol: Chem.ROMol
            molecule to be parameterized.

        Returns
        -------
        tuple
            (parameters of shape [N,2], vjp_fn)

        """
        param_idxs = generate_nonbonded_idxs(mol, self.smirks)

        def param_fn(p):
            params = p[param_idxs]
            sigmas = params[:, 0]
            epsilons = jnp.power(params[:, 1],
                                 2)  # resolves a super annoying singularity
            return jnp.stack([sigmas, epsilons], axis=1)

        return jax.vjp(param_fn, self.params)
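
A short stand-alone sketch of how the (parameters, vjp_fn) pair returned above can be consumed; the toy param_idxs and params below are stand-ins for the handler's state, not values from the source:

import jax
import jax.numpy as jnp

params = jnp.array([[0.3, 0.1], [0.4, 0.2]])  # per-SMIRKS (sigma, sqrt-epsilon) rows
param_idxs = jnp.array([0, 1, 1, 0])          # one entry per atom

def param_fn(p):
    per_atom = p[param_idxs]
    # mirror the epsilon squaring used above
    return jnp.stack([per_atom[:, 0], jnp.power(per_atom[:, 1], 2)], axis=1)

nb_params, vjp_fn = jax.vjp(param_fn, params)
(ff_adjoints,) = vjp_fn(jnp.ones_like(nb_params))  # same shape as params; chain rule flows through the square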
Example #4
    def parameterize(self, mol):
        """
        Parameterize given molecule

        Parameters
        ----------
        mol: Chem.ROMol
            rdkit molecule, should have hydrogens pre-added

        Returns
        -------
        tuple
            (bond_idxs of shape (Q, 2) and dtype np.int32,
            (parameters of shape (Q, 2), vjp_fn: R^(Qx2) -> R^(Px2)))
            System bond indices, parameters, and the vjp_fn.

        """

        bond_idxs, param_idxs = generate_vd_idxs(mol, self.smirks)
        param_fn = functools.partial(parameterize_ligand,
                                     param_idxs=param_idxs)
        sys_params, vjp_fn = jax.vjp(param_fn, self.params)

        return np.array(bond_idxs,
                        dtype=np.int32), (np.array(sys_params,
                                                   dtype=np.float64), vjp_fn)
Example #5
    def critic_loss(
        critic_params: hk.Params,
        generator_params: hk.Params,
        img_batch: ImgBatch,
        latent_batch: LatentBatch,
    ) -> Tuple[jnp.ndarray, Log]:

        batch_size = img_batch.shape[0]
        img_generated = generator.apply(generator_params, latent_batch)

        def f(x):
            return critic.apply(critic_params, x)

        f_real = f(img_batch)

        # Use the vector-jacobian product to efficiently compute the grad for each f_i.
        f_gen, grad_fn = jax.vjp(f, img_generated)
        grad = grad_fn(jnp.ones(f_gen.shape))[0]

        assert_shape(f_real, (batch_size, 1))
        assert_shape(f_gen, (batch_size, 1))
        assert_shape(grad, img_batch.shape)

        flat_grad = grad.reshape(batch_size, -1)
        gp = jnp.square(1 - jnp.linalg.norm(flat_grad, axis=1))
        assert_shape(gp, (batch_size, ))

        loss = jnp.mean(f_gen) - jnp.mean(f_real)
        gradient_penalty = jnp.mean(gp)

        log = {
            "wasserstein": -loss,
            "gradient_penalty": gradient_penalty,
        }

        return loss + 10 * gradient_penalty, log
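
The comment above works because each critic output f(x)_i depends only on image i, so a single VJP with a vector of ones recovers every per-example gradient in one backward pass. A small self-contained check of that identity with a toy critic (not the network above):

import jax
import jax.numpy as jnp

def toy_critic(imgs):  # (B, D) -> (B, 1); row i depends only on imgs[i]
    return jnp.sum(jnp.tanh(imgs), axis=1, keepdims=True)

imgs = jnp.arange(6.0).reshape(3, 2)
out, grad_fn = jax.vjp(toy_critic, imgs)
grad = grad_fn(jnp.ones(out.shape))[0]  # (B, D): gradient of f(x)_i w.r.t. imgs[i]

per_example = jax.vmap(jax.grad(lambda x: toy_critic(x[None])[0, 0]))(imgs)
assert jnp.allclose(grad, per_example)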
Example #6
 def f(x):
     y = x + 1
     p, vjp_f = vjp(
         lambda z: jnp.sin(with_sharding_constraint(z, P(2, 2))), y)
     return vjp_f(p)
Example #7
 def _helper(cond):
     outp, _ = jax.vjp(lambda x: _bijector(root, x)[0], cond)
     _, vjp_fun = jax.vjp(lambda x: _bijector(x, cond)[0], root)
     jac = vjp_fun(jnp.ones_like(outp))[0]
     return (outp, outp, jac)
Example #8
 def _call(x, *params):
     y, vjp_bijector = jax.vjp(lambda x: bijector(x, *params), x)
     ldj = jnp.log(vjp_bijector(jnp.ones_like(y))[0])
     return y, ldj
def uoro_grad(key, theta, x, t0, T, K, s_tilde=None, theta_tilde=None):
    epsilon_perturbation = 1e-7
    epsilon_stability = 1e-7

    t_current = t0
    mystate = x
    total_theta_grad = 0
    total_loss = 0.0

    if s_tilde is None:
        s_tilde = jnp.zeros(mystate.shape)

    if theta_tilde is None:
        theta_tilde = jnp.zeros(theta.shape)

    for i in range(K):
        state_vec_old = mystate

        state_new = single_step(theta, mystate, t_current, T)
        total_loss += loss(state_new) * (t_current < T)

        state_vec_new = state_new

        dl_dstate_old = compute_dL_dstate_old(theta, state_vec_old, t_current,
                                              T)
        dl_dtheta_direct = compute_dL_dtheta_direct(theta, state_vec_old,
                                                    t_current, T)

        indirect_grad = (dl_dstate_old * s_tilde).sum() * theta_tilde
        pseudograds = indirect_grad + dl_dtheta_direct

        state_old_perturbed = state_vec_old + s_tilde * epsilon_perturbation
        state_vec_new_perturbed = single_step(theta, state_old_perturbed,
                                              t_current, T)

        state_deriv_in_direction_s_tilde = (
            state_vec_new_perturbed - state_vec_new) / epsilon_perturbation

        key, skey = jax.random.split(key)
        nus = jnp.round(jax.random.uniform(skey, state_vec_old.shape)) * 2 - 1

        custom_f = lambda param_vector: single_step(
            param_vector, state_vec_old, t_current, T)
        primals, f_vjp = jax.vjp(custom_f, theta)
        direct_theta_tilde_contribution, = f_vjp(nus)

        rho_0 = jnp.sqrt((jnp.linalg.norm(theta_tilde) + epsilon_stability) /
                         (jnp.linalg.norm(state_deriv_in_direction_s_tilde) +
                          epsilon_stability))
        rho_1 = jnp.sqrt(
            (jnp.linalg.norm(direct_theta_tilde_contribution) +
             epsilon_stability) / (jnp.linalg.norm(nus) + epsilon_stability))

        theta_grad = pseudograds
        total_theta_grad += theta_grad

        s_tilde = rho_0 * state_deriv_in_direction_s_tilde + rho_1 * nus
        theta_tilde = theta_tilde / rho_0 + direct_theta_tilde_contribution / rho_1

        mystate = state_new
        t_current += 1

    return (key, total_loss, mystate, s_tilde, theta_tilde), total_theta_grad
        def body_fun(vals):
            """Performs attention for a single batch element and head."""
            batch_loop_idx = vals[0]
            if self._prng is None:
                hash_slice_rng = jax.random.fold_in(rng, batch_loop_idx)
                hash_rng, slice_rng = backend.random.split(hash_slice_rng)
            else:
                # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
                hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
                slice_rng = jax.random.fold_in(rng, batch_loop_idx)
            qk_slice = jax.lax.dynamic_index_in_dim(qk,
                                                    batch_loop_idx,
                                                    axis=0,
                                                    keepdims=False)
            v_slice = jax.lax.dynamic_index_in_dim(v,
                                                   batch_loop_idx,
                                                   axis=0,
                                                   keepdims=False)

            if buckets is None:
                buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
            else:
                buckets_slice = jax.lax.dynamic_index_in_dim(buckets,
                                                             batch_loop_idx,
                                                             axis=0,
                                                             keepdims=False)

            if ct is None:
                out_slice = self.single_call(qk_slice,
                                             v_slice,
                                             buckets_slice,
                                             rng=slice_rng)
            else:

                def _do_single_call(qk_slice, v_slice):
                    return self.single_call(qk_slice,
                                            v_slice,
                                            buckets_slice,
                                            rng=slice_rng)

                ct_slice = jax.lax.dynamic_index_in_dim(ct,
                                                        batch_loop_idx,
                                                        axis=0,
                                                        keepdims=False)
                out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
                qk_ct_slice, v_ct_slice = vjpfun(ct_slice)

            new_vals = (batch_loop_idx + 1, )
            if return_output:
                out_accum = vals[1]
                out_accum = jax.lax.dynamic_update_index_in_dim(out_accum,
                                                                out_slice,
                                                                batch_loop_idx,
                                                                axis=0)
                new_vals = new_vals + (out_accum, )
            if return_state:
                buckets_accum = vals[2]
                buckets_accum = jax.lax.dynamic_update_index_in_dim(
                    buckets_accum, buckets_slice, batch_loop_idx, axis=0)
                new_vals = new_vals + (buckets_accum, )
            if ct is not None:
                qk_ct_accum, v_ct_accum = vals[-2:]
                qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
                    qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
                v_ct_accum = jax.lax.dynamic_update_index_in_dim(
                    v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
                new_vals = new_vals + (qk_ct_accum, v_ct_accum)

            return new_vals
 def binned_attn_vjp(sq, sk, sv, so_ct):
     so, vjpfun = jax.vjp(binned_attn, sq, sk, sv)
     sqkv_ct = vjpfun(so_ct)
     return so, sqkv_ct
Example #12
    def parameterize(self, mol):
        """
        Parameters
        ----------

        mol: Chem.ROMol
            molecule to be parameterized.

        """
        # imported here for optional dependency
        from openeye import oechem
        from openeye import oequacpac

        mb = Chem.MolToMolBlock(mol)
        ims = oechem.oemolistream()
        ims.SetFormat(oechem.OEFormat_SDF)
        ims.openstring(mb)

        for buf_mol in ims.GetOEMols():
            oemol = oechem.OEMol(buf_mol)
            # AromaticityModel.assign(oe_molecule, bcc_collection.aromaticity_model)
            AromaticityModel.assign(oemol)

        # check for cache
        cache_key = 'AM1Cache'
        if not mol.HasProp(cache_key):
            result = oequacpac.OEAssignCharges(
                oemol, oequacpac.OEAM1Charges(symmetrize=True))
            if result is False:
                raise Exception('Unable to assign charges')

            am1_charges = []
            for index, atom in enumerate(oemol.GetAtoms()):
                q = atom.GetPartialCharge() * np.sqrt(constants.ONE_4PI_EPS0)
                am1_charges.append(q)

            mol.SetProp(cache_key, base64.b64encode(pickle.dumps(am1_charges)))

        else:
            am1_charges = pickle.loads(base64.b64decode(
                mol.GetProp(cache_key)))

        bond_idxs = []
        bond_idx_params = []

        for index in range(len(self.smirks)):
            smirk = self.smirks[index]
            param = self.params[index]

            substructure_search = oechem.OESubSearch(smirk)
            substructure_search.SetMaxMatches(0)

            matched_bonds = []
            matches = []
            for match in substructure_search.Match(oemol):

                matched_indices = {
                    atom_match.pattern.GetMapIdx() - 1:
                    atom_match.target.GetIdx()
                    for atom_match in match.GetAtoms()
                    if atom_match.pattern.GetMapIdx() != 0
                }

                matches.append(matched_indices)

            for matched_indices in matches:

                forward_matched_bond = [matched_indices[0], matched_indices[1]]
                reverse_matched_bond = [matched_indices[1], matched_indices[0]]

                if (forward_matched_bond in matched_bonds
                        or reverse_matched_bond in matched_bonds
                        or forward_matched_bond in bond_idxs
                        or reverse_matched_bond in bond_idxs):
                    continue

                matched_bonds.append(forward_matched_bond)
                bond_idxs.append(forward_matched_bond)
                bond_idx_params.append(index)

        bcc_fn = functools.partial(apply_bcc,
                                   bond_idxs=np.array(bond_idxs),
                                   bond_idx_params=np.array(bond_idx_params,
                                                            dtype=np.int32),
                                   am1_charges=np.array(am1_charges))

        charges, vjp_fn = jax.vjp(bcc_fn, self.params)

        return np.array(charges, dtype=np.float64), vjp_fn
def cost_func_vjp(bb, u):
    _, jax_vjp = jax.vjp(cost_func, bb, u)
    directmat = jax_vjp(1.0)
    return directmat[0]
# Apply network to dummy inputs
inputs = np.zeros((1, 32))
# predictions = net_apply(net_params, inputs)
# print ("pred: ", predictions)


def net_apply_reverse(inputs, net_params):
    return net_apply(net_params, inputs)


@jit
def test_loss(net_params, inputs):
    return np.sum(net_apply(net_params, inputs))


primals_out, vjpfun = vjp(partial(net_apply_reverse, inputs), net_params)
print(primals_out)

primals_out, jvpfun = linearize(partial(net_apply_reverse, inputs), net_params)
# primals_out, vp = jvp(net_apply, (net_params, inputs), random.normal(rng, (1, 256)))
print(primals_out)
input("")

for i in range(10):
    import time
    s = time.time()
    out = vjpfun(random.normal(rng, (1, 10)))
    e = time.time()
    print("vjp time: ", (e - s))

    s = time.time()
Example #15
 def rmatvec(v):
     _, vjp_fn = vjp(self.c, z)
     return vjp_fn(v)[0]
Example #16
  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       **kwargs):
    rngs = _pop_rng_and_split(kwargs, len(self.sublayers))

    accumulator_output, *context = output
    context = tuple(context)
    accumulator_output_ct, *context_ct = ct
    context_ct = tuple(context_ct)

    # Forward pass through self.compute_residual. Outputs that will not receive
    # a gradient signal from subsequent layers are moved to aux.
    def call_compute_residual(x, weights):
      res, _ = self.compute_residual.pure_fn(
          x, weights=weights, state=state[0], rng=rngs[0])
      if not isinstance(res, (tuple, list)):
        return res, None
      else:
        n_differentiable = 1
        if self.attention_layer is not None:
          n_differentiable = min(len(res), self.attention_layer.n_in)
        return res[:n_differentiable], res[n_differentiable:]

    stack = context
    inputs = _inputs_from_stack(self.compute_residual, stack)
    outputs, compute_residual_vjpfun, outputs_aux = jax.vjp(
        call_compute_residual, inputs, weights[0], has_aux=True)
    if outputs_aux is not None:
      n_differentiable_outputs = len(outputs)
      outputs = outputs + outputs_aux
    stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

    stack_ct = accumulator_output_ct
    if self.attention_layer is None:
      residual = stack[0] if isinstance(stack, (tuple, list)) else stack
    else:
      inputs = _inputs_from_stack(self.attention_layer, stack)
      (residual, _, attn_inputs_ct, attn_weights_ct
      ) = self.attention_layer.forward_and_or_backward(
          inputs, weights[1], new_state[1], rngs[1],
          output_grad=accumulator_output_ct,
          compute_output=True, update_state=False)
      stack_ct = _outputs_onto_stack(
          self.attention_layer, attn_inputs_ct, stack_ct,
          self.attention_layer.n_out, self.attention_layer.n_in)

    compute_residual_ct = _inputs_from_stack(
        self.compute_residual, stack_ct, self.compute_residual.n_out)
    if outputs_aux is not None:
      if not isinstance(compute_residual_ct, (tuple, list)):
        compute_residual_ct = (compute_residual_ct,)
      compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
      assert len(compute_residual_ct) == n_differentiable_outputs
    (compute_residual_inputs_ct, compute_residual_weights_ct
    ) = compute_residual_vjpfun(compute_residual_ct)
    stack_ct = _outputs_onto_stack(
        self.compute_residual, compute_residual_inputs_ct, stack_ct,
        self.compute_residual.n_out, self.compute_residual.n_in)
    if not isinstance(stack_ct, (tuple, list)):
      stack_ct = (stack_ct,)
    stack_ct = (accumulator_output_ct,) + jax.tree_multimap(
        lambda x, y: x+y, context_ct[:len(stack_ct)], stack_ct
        ) + context_ct[len(stack_ct):]

    reconstructed_x = accumulator_output - residual
    stack = (reconstructed_x,) + context
    if self.attention_layer is None:
      weights_ct = (compute_residual_weights_ct,)
    else:
      weights_ct = (compute_residual_weights_ct, attn_weights_ct)
    return stack, (stack_ct, weights_ct)
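
A minimal sketch (my own, not trax code) of the has_aux pattern used by call_compute_residual above: outputs that should not receive a gradient signal are returned as auxiliary data, so the vjp function only expects cotangents for the differentiable part.

import jax
import jax.numpy as jnp

def residual_fn(x):
    res = jnp.tanh(x)
    passthrough = jnp.round(x)  # moved to aux; no cotangent will be supplied for it
    return res, passthrough

x = jnp.linspace(-1.0, 1.0, 4)
res, vjp_fn, passthrough = jax.vjp(residual_fn, x, has_aux=True)
(x_ct,) = vjp_fn(jnp.ones_like(res))  # cotangent supplied only for `res`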
 def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                           ct_slice):
     output_slice, vjpfun = jax.vjp(forward_slice, query_slice,
                                    q_loop_idx, key, value)
     return output_slice, vjpfun(ct_slice)
Example #18
 def f_vjp(args, cts):
     res, pullback = jax.vjp(f, *args)
     return pullback(cts)
    def reverse_and_grad(self, output, ct, params=(), **kwargs):
        rng = kwargs.pop('rng', None)
        rngs = (None, ) * self._n_layers
        if rng is not None:
            rngs = backend.random.split(rng, self._n_layers)

        # Forward pass through self.pre_attention, while preparing for
        # later backprop.
        # Note: jax.vjp does not allow us to use **kwargs in the signature here.
        def call_pre_attention(x, params, kwargs):
            return self.pre_attention(x, params, **kwargs)

        pre_attention_kwargs = kwargs.copy()
        pre_attention_kwargs['rng'] = rngs[0]
        stack, pre_attention_vjpfun = jax.vjp(call_pre_attention, output,
                                              params[0], pre_attention_kwargs)

        # Backprop through adding the residual
        assert len(ct) == 2
        ct = saved_ct = (ct[0], ct[0], ct[1])

        # Backprop through self.post_attention with respect to the inputs only
        call_post_attention_kwargs = kwargs.copy()
        call_post_attention_kwargs['rng'] = rngs[2]

        def call_post_attention(x):
            return self.post_attention(x, params[2],
                                       **call_post_attention_kwargs)

        # Note: these are *not* the actual inputs to self.post_attention.
        # If self.post_attention is not linear, we will get incorrect gradients.
        dummy_inputs = (stack[-3], stack[-2], stack[-1])
        _, post_attention_vjpfun = jax.vjp(call_post_attention, dummy_inputs)
        (ct, ) = post_attention_vjpfun(ct)

        # Simultaneous forward pass and backprop through the attention mechanism
        attention_kwargs = kwargs.copy()
        attention_kwargs['rng'] = rngs[1]
        stack, ct = self.attention.forward_and_vjp(stack, ct,
                                                   **attention_kwargs)
        attention_params_ct = ()

        # Backprop through self.pre_attention
        (x_ct, pre_attention_params_ct,
         pre_attention_kwargs_ct) = pre_attention_vjpfun(ct)

        # Forward pass for self.post_attention, and backprop with respect to the
        # parameters only
        def call_post_attention2(params, kwargs):
            return self.post_attention(stack, params, **kwargs)

        stack, post_attention_vjpfun = jax.vjp(call_post_attention2, params[2],
                                               call_post_attention_kwargs)
        (post_attention_params_ct,
         post_attention_kwargs_ct) = post_attention_vjpfun(saved_ct)

        # Forward pass through subtracting the residual
        reconstructed_x = self.subtract_top(stack,
                                            params[-1],
                                            rng=rngs[-1],
                                            **kwargs)

        params_ct = (
            pre_attention_params_ct,
            attention_params_ct,
            post_attention_params_ct,
            (),
        )

        # We don't actually backprop through the kwargs, but the API requires that
        # we provide a value for kwargs_ct.
        kwargs_ct = pre_attention_kwargs_ct
        del post_attention_kwargs_ct

        return reconstructed_x, (x_ct, params_ct, kwargs_ct)
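
The "dummy inputs" trick in the comments above is only sound because the VJP of a linear map does not depend on the point at which it is linearized. A small stand-alone check of that fact (not trax code):

import jax
import jax.numpy as jnp

W = jnp.arange(6.0).reshape(2, 3)
linear = lambda x: W @ x

real_x = jnp.array([1.0, 2.0, 3.0])
dummy_x = jnp.zeros(3)
ct = jnp.array([1.0, -1.0])

_, vjp_at_real = jax.vjp(linear, real_x)
_, vjp_at_dummy = jax.vjp(linear, dummy_x)
assert jnp.allclose(vjp_at_real(ct)[0], vjp_at_dummy(ct)[0])  # both equal W.T @ ct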
Example #20
import jax.numpy as jnp

import simsoptpp as sopp
from .._core.optimizable import Optimizable


@jit
def incremental_arclength_pure(d1gamma):
    """
    This function is used in a Python+Jax implementation of the curve arc length formula.
    """

    return jnp.linalg.norm(d1gamma, axis=1)


incremental_arclength_vjp = jit(lambda d1gamma, v: vjp(lambda d1g: incremental_arclength_pure(d1g), d1gamma)[1](v)[0])


@jit
def kappa_pure(d1gamma, d2gamma):
    """
    This function is used in a Python+Jax implementation of formula for curvature.
    """

    return jnp.linalg.norm(jnp.cross(d1gamma, d2gamma), axis=1)/jnp.linalg.norm(d1gamma, axis=1)**3


kappavjp0 = jit(lambda d1gamma, d2gamma, v: vjp(lambda d1g: kappa_pure(d1g, d2gamma), d1gamma)[1](v)[0])
kappavjp1 = jit(lambda d1gamma, d2gamma, v: vjp(lambda d2g: kappa_pure(d1gamma, d2g), d2gamma)[1](v)[0])
kappagrad0 = jit(lambda d1gamma, d2gamma: jacfwd(lambda d1g: kappa_pure(d1g, d2gamma))(d1gamma))
kappagrad1 = jit(lambda d1gamma, d2gamma: jacfwd(lambda d2g: kappa_pure(d1gamma, d2g))(d2gamma))
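
kappa_pure above implements the pointwise curvature formula kappa = |gamma' x gamma''| / |gamma'|^3, and each kappavjp* lambda contracts a cotangent vector against the Jacobian with respect to one argument. A short consistency check between the vjp and the dense Jacobian, repeated here in stand-alone form so it runs without simsopt:

import jax.numpy as jnp
from jax import jacfwd, vjp

def kappa_pure(d1gamma, d2gamma):
    return jnp.linalg.norm(jnp.cross(d1gamma, d2gamma), axis=1)/jnp.linalg.norm(d1gamma, axis=1)**3

d1 = jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
d2 = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
v = jnp.array([1.0, 0.5])

vjp_val = vjp(lambda d1g: kappa_pure(d1g, d2), d1)[1](v)[0]
jac = jacfwd(lambda d1g: kappa_pure(d1g, d2))(d1)  # shape (2, 2, 3)
assert jnp.allclose(vjp_val, jnp.einsum('i,ij...->j...', v, jac))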
Example #21
def value_and_grad(fn, args):
  """Given `fn: (args) -> out, extra`, returns `dout/dargs`."""
  output, vjp_fn, extra = jax.vjp(fn, args, has_aux=True)
  grad = vjp_fn(np.ones_like(output))[0]
  return output, extra, grad
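
A hypothetical call of the helper above (the function `fn` below is my own example, and the jax/np imports of the original module are assumed to be in scope). Note that when `output` is not a scalar, the np.ones_like(output) cotangent makes `grad` the gradient of the sum of the outputs:

import jax.numpy as jnp

def fn(args):
    out = jnp.sum(args ** 2)                 # scalar output
    extra = {"norm": jnp.linalg.norm(args)}  # auxiliary values, returned untouched
    return out, extra

output, extra, grad = value_and_grad(fn, jnp.arange(3.0))  # grad == 2 * args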
Example #22
    "x": 2.2,
    "y": [1.0, 2.2]
}], [], [{
    "x": 3.3,
    "y": [1.0, 2.0, 3.0]
}]])

tangent = ak.Array([[{
    "x": 0.0,
    "y": [1.0]
}, {
    "x": 2.0,
    "y": [1.5, 0.0]
}], [], [{
    "x": 1.5,
    "y": [2.0, 0.5, 1.0]
}]])

primal_result, tangent_result = jax.jvp(func, (primal, ), (tangent, ))
print("resulting types", type(primal_result), type(tangent_result))
print(primal_result)
print(tangent_result)

jit_result = jax.jit(func)(primal)
print("resulting type", type(jit_result))
print(jit_result)

val, func = jax.vjp(func, primal)

print(func(val))
Example #23
    def reverse_and_grad(self, output, ct, params=(), state=(), **kwargs):
        rng = kwargs.pop('rng', None)
        rngs = (None, ) * self._n_layers
        if rng is not None:
            rngs = backend.random.split(rng, self._n_layers)

        # Forward pass through self.pre_attention, while preparing for
        # later backprop.
        def call_pre_attention(x, params):
            res = self.pre_attention(x,
                                     params=params,
                                     state=state[0],
                                     rng=rngs[0],
                                     **kwargs)
            return res

        stack, pre_attention_vjpfun = jax.vjp(call_pre_attention, output,
                                              params[0])

        # Backprop through adding the residual
        assert len(ct) == 2
        ct = saved_ct = (ct[0], ct[0], ct[1])

        # Backprop through self.post_attention with respect to the inputs only
        def call_post_attention(x):
            res = self.post_attention(x,
                                      params=params[2],
                                      state=state[2],
                                      rng=rngs[2],
                                      **kwargs)
            return res

        # Note: these are *not* the actual inputs to self.post_attention.
        # If self.post_attention is not linear, we will get incorrect gradients.
        dummy_inputs = (stack[-3], stack[-2], stack[-1])
        _, post_attention_vjpfun = jax.vjp(call_post_attention, dummy_inputs)
        (ct, ) = post_attention_vjpfun(ct)

        # Simultaneous forward pass and backprop through the attention mechanism
        stack, ct = self.attention.forward_and_backward(stack,
                                                        ct,
                                                        rng=rngs[1],
                                                        **kwargs)
        assert not jax.tree_util.tree_leaves(params[1])
        attention_params_ct = params[1]  # This is valid when params is empty.

        # Backprop through self.pre_attention
        x_ct, pre_attention_params_ct = pre_attention_vjpfun(ct)

        # Forward pass for self.post_attention, and backprop with respect to the
        # parameters only
        def call_post_attention2(params):
            res = self.post_attention(stack,
                                      params=params,
                                      state=state[2],
                                      rng=rngs[2],
                                      **kwargs)
            return res

        stack, post_attention_vjpfun = jax.vjp(call_post_attention2, params[2])
        (post_attention_params_ct, ) = post_attention_vjpfun(saved_ct)

        # Forward pass through subtracting the residual
        reconstructed_x = self.subtract_top(stack,
                                            params=params[-1],
                                            state=state[-1],
                                            rng=rngs[-1],
                                            **kwargs)

        assert not jax.tree_util.tree_leaves(params[-1])
        add_top_params_ct = params[-1]
        params_ct = [
            pre_attention_params_ct,
            attention_params_ct,
            post_attention_params_ct,
            add_top_params_ct,
        ]

        return reconstructed_x, (x_ct, params_ct)
Example #24
 def augmented_dynamics(augmented_state, t, flat_args):
     # Original system augmented with vjp_y, vjp_t and vjp_args.
     y, adjoint, _, _ = unpack(augmented_state)
     dy_dt, vjp_all = vjp(flat_func, y, t, flat_args)
     vjp_a, vjp_t, vjp_args = vjp_all(-adjoint)
     return np.concatenate([dy_dt, vjp_a, vjp_t.reshape(1), vjp_args])
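
For reference, this augmentation matches my reading of the standard continuous-adjoint system, in which a single vjp call with cotangent -adjoint supplies the last three rows:

\frac{d}{dt}\begin{bmatrix} y \\ a \\ g_t \\ g_\theta \end{bmatrix}
= \begin{bmatrix} f(y, t, \theta) \\ -a^\top \,\partial f/\partial y \\ -a^\top \,\partial f/\partial t \\ -a^\top \,\partial f/\partial \theta \end{bmatrix}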
Example #25
 def _jac_diag(inp):
     outp, vjp_fun = jax.vjp(lambda x: _bijector(x, cond)[0], inp)
     return vjp_fun(jnp.ones_like(outp))[0]
Example #26
 def augmented_dynamics(augmented_state, t):
     y, adjoint = unpack(augmented_state)
     dy_dt, vjp_all = vjp(func, y, t)
     vjp_a, vjp_t = vjp_all(adjoint)
     return np.concatenate([dy_dt, vjp_a])
Example #27
 def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
     so, vjpfun = jax.vjp(binned_attn, sqk, sv)
     sqkv_ct = vjpfun(so_ct)
     return so, sqkv_ct
Example #28
 def _test_transformation(self, func, param, msg=None):
   out, f_vjp = jax.vjp(func, param)
   cotangent, = f_vjp(np.ones_like(out).astype(out.dtype))
   self.assertEqual(param.shape, cotangent.shape)
   if not FLAGS.execute_only:
     self.assertNotAllEqual(cotangent, np.zeros_like(cotangent), msg=msg)
Example #29
 def expected_f(x):
     y = x + 1
     p, vjp_f = vjp(lambda z: jnp.sin(z), y)
     return vjp_f(p)
    def full_vjp_func(func_args):
        # Trace the tagged function
        typed_jaxpr = jax.make_jaxpr(tagged_func)(*func_args)
        jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
        layer_tags, loss_tags = extract_tags(jaxpr)

        layer_vars_flat = jax.tree_flatten([tag.invars
                                            for tag in layer_tags])[0]
        layer_input_vars = tuple(set(layer_vars_flat))

        def forward():
            own_func_args = func_args
            # Mapping from variable -> value
            env = dict()
            read = functools.partial(tgm.read_env, env)
            write = functools.partial(tgm.write_env, env)

            # Bind args and consts to environment
            write(jax.core.unitvar, jax.core.unit)
            jax_util.safe_map(write, jaxpr.invars,
                              jax.tree_flatten(own_func_args)[0])
            jax_util.safe_map(write, jaxpr.constvars, consts)

            # Loop through equations and evaluate primitives using `bind`
            num_losses_passed = 0
            for eqn in jaxpr.eqns:
                tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars),
                                 write)
                if isinstance(eqn.primitive, tags.LossTag):
                    num_losses_passed += 1
                    if num_losses_passed == len(loss_tags):
                        break
            if num_losses_passed != len(loss_tags):
                raise ValueError("This should be unreachable.")

            return jax_util.safe_map(read, layer_input_vars)

        def forward_aux(aux):
            own_func_args = func_args
            # Mapping from variable -> value
            env = dict()
            read = functools.partial(tgm.read_env, env)

            def write(var, val):
                if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)):
                    val = val + aux[var] if var in aux else val
                env[var] = val

            # Bind args and consts to environment
            write(jax.core.unitvar, jax.core.unit)
            jax_util.safe_map(write, jaxpr.invars,
                              jax.tree_flatten(own_func_args)[0])
            jax_util.safe_map(write, jaxpr.constvars, consts)

            # Loop through equations and evaluate primitives using `bind`
            num_losses_passed = 0
            losses_inputs_values = []
            losses_kwargs_values = []
            for eqn in jaxpr.eqns:
                input_values = jax_util.safe_map(read, eqn.invars)
                tgm.evaluate_eqn(eqn, input_values, write)
                if isinstance(eqn.primitive, tags.LossTag):
                    loss = eqn.primitive.loss(*input_values,
                                              weight=eqn.params["weight"])
                    losses_inputs_values.append(loss.inputs)
                    losses_kwargs_values.append(
                        dict(targets=loss.targets,
                             weight=eqn.params["weight"]))
                    num_losses_passed += 1
                    if num_losses_passed == len(loss_tags):
                        break
            if num_losses_passed != len(loss_tags):
                raise ValueError("This should be unreachable.")
            # Read the inputs to the loss functions, but also return the target values
            return tuple(losses_inputs_values), tuple(losses_kwargs_values)

        layer_input_values = forward()
        primals_dict = dict(zip(layer_input_vars, layer_input_values))
        primals_dict.update(zip(jaxpr.invars, jax.tree_flatten(func_args)[0]))
        aux_values = jax.tree_map(jnp.zeros_like, layer_input_values)
        aux_dict = dict(zip(layer_input_vars, aux_values))

        losses_args, aux_vjp, losses_kwargs = jax.vjp(forward_aux,
                                                      aux_dict,
                                                      has_aux=True)
        losses = tuple(
            tag.primitive.loss(*inputs, **kwargs) for tag, inputs, kwargs in
            zip(loss_tags, losses_args, losses_kwargs))

        def vjp_func(tangents):
            all_tangents = aux_vjp(tangents)
            tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]
            inputs_tangents = jax.tree_flatten(inputs_tangents)[0]
            tangents_dict.update(zip(jaxpr.invars, inputs_tangents))

            read_primals = functools.partial(tgm.read_env, primals_dict)
            read_tangents = functools.partial(tgm.read_env, tangents_dict)
            layers_info = []
            for jaxpr_eqn in layer_tags:
                layer_tag = _unbox_layer_tag(jaxpr_eqn)
                info = dict()
                primals = jax_util.safe_map(read_primals,
                                            tuple(jaxpr_eqn.invars))
                (
                    info["outputs"],
                    info["inputs"],
                    info["params"],
                ) = layer_tag.split_all_inputs(primals)
                tangents = jax_util.safe_map(read_tangents,
                                             tuple(jaxpr_eqn.invars))
                (
                    info["outputs_tangent"],
                    info["inputs_tangent"],
                    info["params_tangent"],
                ) = layer_tag.split_all_inputs(tangents)
                layers_info.append(info)
            return tuple(layers_info)

        return losses, vjp_func
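
A minimal stand-alone illustration (my own, not taken from the source) of the trick forward_aux relies on: adding a zero-valued auxiliary term to an intermediate value and differentiating with respect to it reads off that intermediate's cotangent.

import jax
import jax.numpy as jnp

def f_with_aux(x, aux):
    h = jnp.tanh(x) + aux  # aux is zero at the evaluation point, so the primal value is unchanged
    return jnp.sum(h ** 2)

x = jnp.arange(3.0)
loss, vjp_fn = jax.vjp(f_with_aux, x, jnp.zeros_like(x))
x_ct, h_ct = vjp_fn(jnp.ones_like(loss))  # h_ct is d(loss)/d(h), the cotangent of the intermediate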