Esempio n. 1
0
def get_alt_atom14(aatype, positions, mask):
    """Get alternative atom14 positions.

  Constructs renamed atom positions for ambiguous residues.

  Jumper et al. (2021) Suppl. Table 3 "Ambiguous atom names due to 180 degree-
  rotation-symmetry"

  Args:
    aatype: Amino acid at given position
    positions: Atom positions as r3.Vecs in atom14 representation, (N, 14)
    mask: Atom masks in atom14 representation, (N, 14)
  Returns:
    renamed atom positions, renamed atom mask
  """
    # pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14)
    renaming_transform = utils.batched_gather(jnp.asarray(RENAMING_MATRICES),
                                              aatype)

    def _rename(x):
        # (N, 14) -> (N, 14): out[n, j] = sum_i x[n, i] * T[n, i, j].
        # The multiplication must happen per array leaf; multiplying the
        # r3.Vecs container itself by an array does not operate on its
        # x/y/z components.
        return jnp.sum(x[..., None] * renaming_transform, axis=1)

    alternative_positions = jax.tree_map(_rename, positions)

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position)
    alternative_mask = _rename(mask)

    return alternative_positions, alternative_mask
Esempio n. 2
0
def frames_and_literature_positions_to_atom14_pos(
        aatype: jnp.ndarray,  # (N)
        all_frames_to_global: r3.Rigids  # (N, 8)
) -> r3.Vecs:  # (N, 14)
    """Place the literature (idealized) atom14 positions into global space.

  Each atom's idealized local position is moved into the global frame by the
  frame of the rigid group the atom belongs to.

  Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11

  Args:
    aatype: aatype for each residue.
    all_frames_to_global: All per residue coordinate frames.
  Returns:
    Positions of all atom coordinates in global frame, with non-existing
    atoms zeroed out.
  """
    # One-hot selector telling, for each of the 14 atom slots, which of the
    # 8 rigid groups it belongs to: shape (N, 14, 8).
    atom_group_onehot = jax.nn.one_hot(
        utils.batched_gather(residue_constants.restype_atom14_to_rigid_group,
                             aatype),
        num_classes=8)

    def _select_group_frame(component):
        # (N, 8) -> (N, 14): each atom picks its own group's frame component.
        return jnp.sum(component[:, None, :] * atom_group_onehot, axis=-1)

    # r3.Rigids with shape (N, 14)
    atom_frames = jax.tree_map(_select_group_frame, all_frames_to_global)

    # Idealized positions in each atom's local group frame, r3.Vecs (N, 14).
    local_positions = r3.vecs_from_tensor(
        utils.batched_gather(
            residue_constants.restype_atom14_rigid_group_positions, aatype))

    # Local -> global, r3.Vecs (N, 14).
    global_positions = r3.rigids_mul_vecs(atom_frames, local_positions)

    # Zero the slots that hold no atom for the given residue type.
    atom_exists = utils.batched_gather(residue_constants.restype_atom14_mask,
                                       aatype)
    return jax.tree_map(lambda comp: comp * atom_exists, global_positions)
Esempio n. 3
0
def atom37_to_atom14(
    atom37_data: jnp.ndarray,  # (N, 37, ...)
    batch: Dict[str, jnp.ndarray]) -> jnp.ndarray:  # (N, 14, ...)
    """Convert atom37 to atom14 representation.

  Args:
    atom37_data: Per-atom data in atom37 layout; rank 2 or rank 3 with the
      atom axis at position 1.
    batch: Dict that must contain 'residx_atom14_to_atom37' (atom37 index for
      every atom14 slot of every residue) and 'atom14_atom_exists' (mask of
      existing atom14 slots).
  Returns:
    The data gathered into atom14 layout, with entries of non-existing atoms
    multiplied by zero.
  """
    assert len(atom37_data.shape) in [2, 3]
    assert 'residx_atom14_to_atom37' in batch
    assert 'atom14_atom_exists' in batch

    atom14_data = utils.batched_gather(atom37_data,
                                       batch['residx_atom14_to_atom37'],
                                       batch_dims=1)

    atom_exists = batch['atom14_atom_exists'].astype(atom14_data.dtype)
    if len(atom37_data.shape) == 3:
        # Broadcast the per-atom mask over the trailing feature axis.
        atom_exists = atom_exists[:, :, None]
    return atom14_data * atom_exists
Esempio n. 4
0
def torsion_angles_to_frames(
    aatype: jnp.ndarray,  # (N)
    backb_to_global: r3.Rigids,  # (N)
    torsion_angles_sin_cos: jnp.ndarray  # (N, 7, 2)
) -> r3.Rigids:  # (N, 8)
    """Build the 8 rigid group frames of every residue from its torsions.

  Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" lines 2-10
  Jumper et al. (2021) Suppl. Alg. 25 "makeRotX"

  Args:
    aatype: aatype for each residue
    backb_to_global: Rigid transformations describing transformation from
      backbone frame to global frame.
    torsion_angles_sin_cos: sin and cosine of the 7 torsion angles
  Returns:
    Frames corresponding to all the Sidechain Rigid Transforms, mapped to
    the global frame.
  """
    assert len(aatype.shape) == 1
    assert len(backb_to_global.rot.xx.shape) == 1
    assert len(torsion_angles_sin_cos.shape) == 3
    assert torsion_angles_sin_cos.shape[1:] == (7, 2)

    # Per-residue default frame of each rigid group, r3.Rigids (N, 8).
    default_frames = r3.rigids_from_tensor4x4(
        utils.batched_gather(
            residue_constants.restype_rigid_group_default_frame, aatype))

    # Prepend an identity rotation (sin=0, cos=1) for the backbone group so
    # the angle arrays line up with the 8 rigid groups: (N, 8).
    (n_res,) = aatype.shape
    sin = jnp.concatenate(
        [jnp.zeros([n_res, 1]), torsion_angles_sin_cos[..., 0]], axis=-1)
    cos = jnp.concatenate(
        [jnp.ones([n_res, 1]), torsion_angles_sin_cos[..., 1]], axis=-1)
    zero = jnp.zeros_like(sin)
    one = jnp.ones_like(sin)

    # Each group frame rotates about its x-axis; r3.Rots with shape (N, 8).
    x_rotations = r3.Rots(one, zero, zero,
                          zero, cos, -sin,
                          zero, sin, cos)
    frames_to_parent = r3.rigids_mul_rots(default_frames, x_rotations)

    # Groups 0-4 map straight to the backbone frame, but chi2..chi4 (groups
    # 5-7) map to the preceding chi frame, so compose them cumulatively
    # starting from the chi1 frame.
    chain_to_backb = [jax.tree_map(lambda x: x[:, 4], frames_to_parent)]
    for group_idx in (5, 6, 7):
        to_prev = jax.tree_map(lambda x, g=group_idx: x[:, g],
                               frames_to_parent)
        chain_to_backb.append(r3.rigids_mul_rigids(chain_to_backb[-1],
                                                   to_prev))

    def _stitch(head, chi2, chi3, chi4):
        # Keep groups 0-4 as-is, append the chained chi2..chi4 frames.
        return jnp.concatenate(
            [head[:, 0:5], chi2[:, None], chi3[:, None], chi4[:, None]],
            axis=-1)

    # r3.Rigids (N, 8), every group expressed relative to the backbone.
    frames_to_backb = jax.tree_map(_stitch, frames_to_parent,
                                   *chain_to_backb[1:])

    # Lift to global: broadcast (N, 1) backbone frames against (N, 8).
    return r3.rigids_mul_rigids(
        jax.tree_map(lambda x: x[:, None], backb_to_global),
        frames_to_backb)
Esempio n. 5
0
def atom37_to_torsion_angles(
    aatype: jnp.ndarray,  # (B, N)
    all_atom_pos: jnp.ndarray,  # (B, N, 37, 3)
    all_atom_mask: jnp.ndarray,  # (B, N, 37)
    placeholder_for_undefined=False,
) -> Dict[str, jnp.ndarray]:
    """Computes the 7 torsion angles (in sin, cos encoding) for each residue.

  The 7 torsion angles are in the order
  '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
  here pre_omega denotes the omega torsion angle between the given amino acid
  and the previous amino acid.

  Args:
    aatype: Amino acid type, given as array with integers. Values above 20
      are clamped to 20 ('Unknown').
    all_atom_pos: atom37 representation of all atom coordinates.
    all_atom_mask: atom37 representation of mask on all atom coordinates.
    placeholder_for_undefined: if True, every torsion angle whose mask is 0
      is replaced by the placeholder (sin=1, cos=0) in both returned angle
      arrays; if False the angles are returned as computed.
  Returns:
    Dict containing:
      * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final
        2 dimensions denote sin and cos respectively
      * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
        with the angle shifted by pi for all chi angles affected by the naming
        ambiguities.
      * 'torsion_angles_mask': Mask for which chi angles are present.
  """

    # Map aatype > 20 to 'Unknown' (20).
    aatype = jnp.minimum(aatype, 20)

    # Compute the backbone angles.
    num_batch, num_res = aatype.shape

    # Shift atoms by one residue so index i holds residue i-1; the first
    # residue gets zero positions and a zero mask.
    pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32)
    prev_all_atom_pos = jnp.concatenate([pad, all_atom_pos[:, :-1, :, :]],
                                        axis=1)

    pad = jnp.zeros([num_batch, 1, 37], jnp.float32)
    prev_all_atom_mask = jnp.concatenate([pad, all_atom_mask[:, :-1, :]],
                                         axis=1)

    # For each torsion angle collect the 4 atom positions that define this angle.
    # shape (B, N, atoms=4, xyz=3)
    pre_omega_atom_pos = jnp.concatenate(
        [
            prev_all_atom_pos[:, :, 1:3, :],  # prev CA, C
            all_atom_pos[:, :, 0:2, :]  # this N, CA
        ],
        axis=-2)
    phi_atom_pos = jnp.concatenate(
        [
            prev_all_atom_pos[:, :, 2:3, :],  # prev C
            all_atom_pos[:, :, 0:3, :]  # this N, CA, C
        ],
        axis=-2)
    psi_atom_pos = jnp.concatenate(
        [
            all_atom_pos[:, :, 0:3, :],  # this N, CA, C
            all_atom_pos[:, :, 4:5, :]  # this O
        ],
        axis=-2)

    # Collect the masks from these atoms.
    # Shape [batch, num_res]
    pre_omega_mask = (
        jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1)  # prev CA, C
        * jnp.prod(all_atom_mask[:, :, 0:2], axis=-1))  # this N, CA
    phi_mask = (
        prev_all_atom_mask[:, :, 2]  # prev C
        * jnp.prod(all_atom_mask[:, :, 0:3], axis=-1))  # this N, CA, C
    psi_mask = (
        jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) *  # this N, CA, C
        all_atom_mask[:, :, 4])  # this O

    # Collect the atoms for the chi-angles.
    # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
    chi_atom_indices = get_chi_atom_indices()
    # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
    atom_indices = utils.batched_gather(params=chi_atom_indices,
                                        indices=aatype,
                                        axis=0,
                                        batch_dims=0)
    # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
    chis_atom_pos = utils.batched_gather(params=all_atom_pos,
                                         indices=atom_indices,
                                         axis=-2,
                                         batch_dims=2)

    # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
    chi_angles_mask = list(residue_constants.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = jnp.asarray(chi_angles_mask)

    # Compute the chi angle mask. I.e. which chis angles exist according to the
    # aatype. Shape [batch, num_res, chis=4].
    chis_mask = utils.batched_gather(params=chi_angles_mask,
                                     indices=aatype,
                                     axis=0,
                                     batch_dims=0)

    # Constrain the chis_mask to those chis, where the ground truth coordinates of
    # all defining four atoms are available.
    # Gather the chi angle atoms mask. Shape: [batch, num_res, chis=4, atoms=4].
    chi_angle_atoms_mask = utils.batched_gather(params=all_atom_mask,
                                                indices=atom_indices,
                                                axis=-1,
                                                batch_dims=2)
    # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
    chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=-1)
    chis_mask = chis_mask * chi_angle_atoms_mask.astype(jnp.float32)

    # Stack all torsion angle atom positions.
    # Shape (B, N, torsions=7, atoms=4, xyz=3)
    torsions_atom_pos = jnp.concatenate([
        pre_omega_atom_pos[:, :, None, :, :], phi_atom_pos[:, :, None, :, :],
        psi_atom_pos[:, :, None, :, :], chis_atom_pos
    ],
                                        axis=2)

    # Stack up masks for all torsion angles.
    # shape (B, N, torsions=7)
    torsion_angles_mask = jnp.concatenate([
        pre_omega_mask[:, :, None], phi_mask[:, :, None], psi_mask[:, :, None],
        chis_mask
    ],
                                          axis=2)

    # Create a frame from the first three atoms:
    # First atom: point on x-y-plane
    # Second atom: point on negative x-axis
    # Third atom: origin
    # r3.Rigids (B, N, torsions=7)
    torsion_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :,
                                                                  1, :]),
        origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
        point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :,
                                                                0, :]))

    # Compute the position of the fourth atom in this frame (y and z coordinate
    # define the chi angle)
    # r3.Vecs (B, N, torsions=7)
    fourth_atom_rel_pos = r3.rigids_mul_vecs(
        r3.invert_rigids(torsion_frames),
        r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))

    # Normalize to have the sin and cos of the torsion angle.
    # jnp.ndarray (B, N, torsions=7, sincos=2)
    torsion_angles_sin_cos = jnp.stack(
        [fourth_atom_rel_pos.z, fourth_atom_rel_pos.y], axis=-1)
    torsion_angles_sin_cos /= jnp.sqrt(
        jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) +
        1e-8)

    # Mirror psi, because we computed it from the Oxygen-atom.
    torsion_angles_sin_cos *= jnp.asarray([1., 1., -1., 1., 1., 1.,
                                           1.])[None, None, :, None]

    # Create alternative angles for ambiguous atom names: chi angles flagged
    # in chi_pi_periodic get both sin and cos negated (shift by pi).
    chi_is_ambiguous = utils.batched_gather(
        jnp.asarray(residue_constants.chi_pi_periodic), aatype)
    mirror_torsion_angles = jnp.concatenate(
        [jnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous],
        axis=-1)
    alt_torsion_angles_sin_cos = (torsion_angles_sin_cos *
                                  mirror_torsion_angles[:, :, :, None])

    if placeholder_for_undefined:
        # Add placeholder torsions (sin=1, cos=0) in place of undefined
        # torsion angles (e.g. N-terminus pre-omega).
        placeholder_torsions = jnp.stack([
            jnp.ones(torsion_angles_sin_cos.shape[:-1]),
            jnp.zeros(torsion_angles_sin_cos.shape[:-1])
        ],
                                         axis=-1)
        defined = torsion_angles_mask[..., None]

        def _fill_undefined(angles):
            # Blend: keep defined angles, substitute the placeholder elsewhere.
            return angles * defined + placeholder_torsions * (1 - defined)

        torsion_angles_sin_cos = _fill_undefined(torsion_angles_sin_cos)
        alt_torsion_angles_sin_cos = _fill_undefined(
            alt_torsion_angles_sin_cos)

    return {
        'torsion_angles_sin_cos': torsion_angles_sin_cos,  # (B, N, 7, 2)
        'alt_torsion_angles_sin_cos':
        alt_torsion_angles_sin_cos,  # (B, N, 7, 2)
        'torsion_angles_mask': torsion_angles_mask  # (B, N, 7)
    }
Esempio n. 6
0
def atom37_to_frames(
        aatype: jnp.ndarray,  # (...)
        all_atom_positions: jnp.ndarray,  # (..., 37, 3)
        all_atom_mask: jnp.ndarray,  # (..., 37)
) -> Dict[str, jnp.ndarray]:
    """Computes the frames for the up to 8 rigid groups for each residue.

  The rigid groups are defined by the possible torsions in a given amino acid.
  We group the atoms according to their dependence on the torsion angles into
  "rigid groups".  E.g., the position of atoms in the chi2-group depend on
  chi1 and chi2, but do not depend on chi3 or chi4.
  Jumper et al. (2021) Suppl. Table 2 and corresponding text.

  Args:
    aatype: Amino acid type, given as array with integers.
    all_atom_positions: atom37 representation of all atom coordinates.
    all_atom_mask: atom37 representation of mask on all atom coordinates.
  Returns:
    Dictionary containing:
      * 'rigidgroups_gt_frames': 8 Frames corresponding to 'all_atom_positions'
           represented as flat 12 dimensional array.
      * 'rigidgroups_gt_exists': Mask denoting whether the atom positions for
          the given frame are available in the ground truth, e.g. if they were
          resolved in the experiment.
      * 'rigidgroups_group_exists': Mask denoting whether given group is in
          principle present for given amino acid type.
      * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
          affected by naming ambiguity.
      * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming
          corresponding to 'all_atom_positions' represented as flat
          12 dimensional array.
  """
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'
    aatype_in_shape = aatype.shape

    # If there is a batch axis, just flatten it away, and reshape everything
    # back at the end of the function.
    aatype = jnp.reshape(aatype, [-1])
    all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3])
    all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])

    # Create an array with the atom names.
    # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)

    # 0: backbone frame
    restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']

    # 3: 'psi-group'
    restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # 4,5,6,7: 'chi1,2,3,4-group'
    for restype, restype_letter in enumerate(residue_constants.restypes):
        resname = residue_constants.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if residue_constants.chi_angles_mask[restype][chi_idx]:
                atom_names = residue_constants.chi_angles_atoms[resname][
                    chi_idx]
                # The last three of the chi-defining atoms span the frame.
                restype_rigidgroup_base_atom_names[restype, chi_idx +
                                                   4, :] = atom_names[1:]

    # Create mask for existing rigid groups.
    restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_mask[:, 0] = 1
    restype_rigidgroup_mask[:, 3] = 1
    restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask

    # Translate atom names into atom37 indices; unused ('') slots map to 0
    # and are masked out through group_exists below.
    lookuptable = residue_constants.atom_order.copy()
    lookuptable[''] = 0
    restype_rigidgroup_base_atom37_idx = np.vectorize(
        lambda x: lookuptable[x])(restype_rigidgroup_base_atom_names)

    # Compute the gather indices for all residues in the chain.
    # shape (N, 8, 3)
    residx_rigidgroup_base_atom37_idx = utils.batched_gather(
        restype_rigidgroup_base_atom37_idx, aatype)

    # Gather the base atom positions for each rigid group.
    base_atom_pos = utils.batched_gather(all_atom_positions,
                                         residx_rigidgroup_base_atom37_idx,
                                         batch_dims=1)

    # Compute the Rigids.
    gt_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
        origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
        point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :]))

    # Compute a mask whether the group exists.
    # (N, 8)
    group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype)

    # Compute a mask whether ground truth exists for the group: all three
    # base atoms must be present.
    gt_atoms_exist = utils.batched_gather(  # shape (N, 8, 3)
        all_atom_mask.astype(jnp.float32),
        residx_rigidgroup_base_atom37_idx,
        batch_dims=1)
    gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists  # (N, 8)

    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
    rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
    rots[0, 0, 0] = -1
    rots[0, 2, 2] = -1
    gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots))

    # The frames for ambiguous rigid groups are just rotated by 180 degree around
    # the x-axis. The ambiguous group is always the last chi-group.
    restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32),
                                      [21, 8, 1, 1])

    for resname in residue_constants.residue_atom_renaming_swaps:
        restype = residue_constants.restype_order[
            residue_constants.restype_3to1[resname]]
        # Index of the last existing chi angle for this residue type.
        chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1

    # Gather the ambiguity information for each residue.
    residx_rigidgroup_is_ambiguous = utils.batched_gather(
        restype_rigidgroup_is_ambiguous, aatype)
    residx_rigidgroup_ambiguity_rot = utils.batched_gather(
        restype_rigidgroup_rots, aatype)

    # Create the alternative ground truth frames.
    alt_gt_frames = r3.rigids_mul_rots(
        gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot))

    gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames)
    alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames)

    def _restore_layout(x, *trailing_dims):
        # Undo the flattening of leading axes done at the start.
        return jnp.reshape(x, tuple(aatype_in_shape) + trailing_dims)

    return {
        'rigidgroups_gt_frames':
        _restore_layout(gt_frames_flat12, 8, 12),  # (..., 8, 12)
        'rigidgroups_gt_exists':
        _restore_layout(gt_exists, 8),  # (..., 8)
        'rigidgroups_group_exists':
        _restore_layout(group_exists, 8),  # (..., 8)
        'rigidgroups_group_is_ambiguous':
        _restore_layout(residx_rigidgroup_is_ambiguous, 8),  # (..., 8)
        'rigidgroups_alt_gt_frames':
        _restore_layout(alt_gt_frames_flat12, 8, 12),  # (..., 8, 12)
    }
Esempio n. 7
0
def find_structural_violations(
        batch: Dict[str, jnp.ndarray],
        atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
        config: ml_collections.ConfigDict):
    """Computes several checks for structural violations.

    Args:
      batch: Feature dict; this function reads 'atom14_atom_exists',
        'residue_index', 'aatype' and 'residx_atom14_to_atom37'.
      atom14_pred_positions: Predicted atom positions in atom14 layout.
      config: Reads 'violation_tolerance_factor' and
        'clash_overlap_tolerance'.

    Returns:
      Nested dict with between-residue bond/angle/clash terms, within-residue
      violation terms, and a combined per-residue violation mask.
    """

    # Compute between residue backbone violations of bonds and angles.
    connection_violations = all_atom.between_residue_bond_loss(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
        residue_index=batch['residue_index'].astype(jnp.float32),
        aatype=batch['aatype'],
        tolerance_factor_soft=config.violation_tolerance_factor,
        tolerance_factor_hard=config.violation_tolerance_factor)

    # Compute the Van der Waals radius for every atom
    # (the first letter of the atom name is the element type).
    # Converted to an array up front: JAX gather ops reject plain Python
    # lists as inputs.
    atomtype_radius = jnp.asarray([
        residue_constants.van_der_waals_radius[name[0]]
        for name in residue_constants.atom_types
    ])
    # Shape: (N, 14).
    atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
        atomtype_radius, batch['residx_atom14_to_atom37'])

    # Compute the between residue clash loss.
    between_residue_clashes = all_atom.between_residue_clash_loss(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch['atom14_atom_exists'],
        atom14_atom_radius=atom14_atom_radius,
        residue_index=batch['residue_index'],
        overlap_tolerance_soft=config.clash_overlap_tolerance,
        overlap_tolerance_hard=config.clash_overlap_tolerance)

    # Compute all within-residue violations (clashes,
    # bond length and angle violations).
    restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
        overlap_tolerance=config.clash_overlap_tolerance,
        bond_length_tolerance_factor=config.violation_tolerance_factor)
    atom14_dists_lower_bound = utils.batched_gather(
        restype_atom14_bounds['lower_bound'], batch['aatype'])
    atom14_dists_upper_bound = utils.batched_gather(
        restype_atom14_bounds['upper_bound'], batch['aatype'])
    within_residue_violations = all_atom.within_residue_violations(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch['atom14_atom_exists'],
        atom14_dists_lower_bound=atom14_dists_lower_bound,
        atom14_dists_upper_bound=atom14_dists_upper_bound,
        tighten_bounds_for_loss=0.0)

    # Combine them to a single per-residue violation mask (used later for LDDT):
    # a residue is flagged if any of the three checks flags it.
    per_residue_violations_mask = jnp.max(
        jnp.stack([
            connection_violations['per_residue_violation_mask'],
            jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1),
            jnp.max(within_residue_violations['per_atom_violations'], axis=-1)
        ]),
        axis=0)

    return {
        'between_residues': {
            'bonds_c_n_loss_mean':
            connection_violations['c_n_loss_mean'],  # ()
            'angles_ca_c_n_loss_mean':
            connection_violations['ca_c_n_loss_mean'],  # ()
            'angles_c_n_ca_loss_mean':
            connection_violations['c_n_ca_loss_mean'],  # ()
            'connections_per_residue_loss_sum':
            connection_violations['per_residue_loss_sum'],  # (N)
            'connections_per_residue_violation_mask':
            connection_violations['per_residue_violation_mask'],  # (N)
            'clashes_mean_loss':
            between_residue_clashes['mean_loss'],  # ()
            'clashes_per_atom_loss_sum':
            between_residue_clashes['per_atom_loss_sum'],  # (N, 14)
            'clashes_per_atom_clash_mask':
            between_residue_clashes['per_atom_clash_mask'],  # (N, 14)
        },
        'within_residues': {
            'per_atom_loss_sum':
            within_residue_violations['per_atom_loss_sum'],  # (N, 14)
            'per_atom_violations':
            within_residue_violations['per_atom_violations'],  # (N, 14),
        },
        'total_per_residue_violations_mask':
        per_residue_violations_mask,  # (N)
    }