示例#1
0
    def test_frame_aligned_point_error_matches_expected(
            self, target_positions, pred_positions, expected_alddt):
        """FAPE on fixed point sets matches the parametrized expected value.

        The very same frame object (from `get_identity_rigid(2)`) is used as
        both predicted and target frames, so only the position arguments can
        contribute to the error.
        """
        frames = get_identity_rigid(2)

        target_vecs = r3.vecs_from_tensor(np.array(target_positions))
        pred_vecs = r3.vecs_from_tensor(np.array(pred_positions))
        num_points = target_vecs.x.shape[0]

        fape = all_atom.frame_aligned_point_error(
            frames,                # pred_frames
            frames,                # target_frames
            np.ones(2),            # frames_mask
            pred_vecs,
            target_vecs,
            np.ones(num_points),   # positions_mask
            L1_CLAMP_DISTANCE,
            L1_CLAMP_DISTANCE,
            epsilon=0)
        self.assertAlmostEqual(fape, expected_alddt)
示例#2
0
def frames_and_literature_positions_to_atom14_pos(
        aatype: jnp.ndarray,  # (N)
        all_frames_to_global: r3.Rigids  # (N, 8)
) -> r3.Vecs:  # (N, 14)
    """Put atom literature positions (atom14 encoding) in each rigid group.

    Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11

    Args:
      aatype: aatype for each residue.
      all_frames_to_global: All per residue coordinate frames (8 rigid groups
        per residue).

    Returns:
      Positions of all atom coordinates in global frame (atom14 layout). Each
      coordinate is multiplied by `residue_constants.restype_atom14_mask` for
      the residue type, which masks out atoms that do not exist.
    """
    # Pick the appropriate transform for every atom: look up the rigid group
    # of each atom14 slot and one-hot encode it.
    residx_to_group_idx = utils.batched_gather(
        residue_constants.restype_atom14_to_rigid_group, aatype)
    group_mask = jax.nn.one_hot(residx_to_group_idx,
                                num_classes=8)  # shape (N, 14, 8)

    # Select each atom's frame by contracting the one-hot over the group axis,
    # leaf by leaf over the Rigids pytree. r3.Rigids with shape (N, 14).
    # `jax.tree_util.tree_map` is used because the `jax.tree_map` alias is
    # deprecated and has been removed from recent JAX releases.
    map_atoms_to_global = jax.tree_util.tree_map(
        lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1),
        all_frames_to_global)

    # Gather the literature atom positions for each residue.
    # r3.Vecs with shape (N, 14)
    lit_positions = r3.vecs_from_tensor(
        utils.batched_gather(
            residue_constants.restype_atom14_rigid_group_positions, aatype))

    # Transform each atom from its local frame to the global frame.
    # r3.Vecs with shape (N, 14)
    pred_positions = r3.rigids_mul_vecs(map_atoms_to_global, lit_positions)

    # Mask out non-existing atoms.
    mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype)
    pred_positions = jax.tree_util.tree_map(lambda x: x * mask, pred_positions)

    return pred_positions
示例#3
0
def sidechain_loss(batch, value, config):
    """All Atom FAPE Loss using renamed rigids.

    The ground-truth rigid-group frames are blended between the original and
    the alternative naming, weighted by `value['alt_naming_is_better']`. FAPE
    is then computed between the last layer of predicted sidechain frames and
    atom positions and the (renamed) ground truth.

    Args:
      batch: Dict with 'rigidgroups_gt_frames' and 'rigidgroups_alt_gt_frames'
        (flattened here to 12 numbers per frame) and 'rigidgroups_gt_exists'.
      value: Dict with 'alt_naming_is_better' (indexed as [:, None, None],
        presumably one weight per residue), 'renamed_atom14_gt_positions',
        'renamed_atom14_gt_exists', and 'sidechains' holding 'frames' and
        'atom_pos' with a leading layer axis (only index -1 is used).
      config: Config providing `sidechain.atom_clamp_distance` and
        `sidechain.length_scale`.

    Returns:
      Dict with 'fape' and 'loss', both set to the same FAPE value.
    """
    # Rename Frames
    # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7
    alt_naming_is_better = value['alt_naming_is_better']
    renamed_gt_frames = ((1. - alt_naming_is_better[:, None, None]) *
                         batch['rigidgroups_gt_frames'] +
                         alt_naming_is_better[:, None, None] *
                         batch['rigidgroups_alt_gt_frames'])

    flat_gt_frames = r3.rigids_from_tensor_flat12(
        jnp.reshape(renamed_gt_frames, [-1, 12]))
    flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists'], [-1])

    flat_gt_positions = r3.vecs_from_tensor(
        jnp.reshape(value['renamed_atom14_gt_positions'], [-1, 3]))
    flat_positions_mask = jnp.reshape(value['renamed_atom14_gt_exists'], [-1])

    # Compute frame_aligned_point_error score for the final layer.
    pred_frames = value['sidechains']['frames']
    pred_positions = value['sidechains']['atom_pos']

    def _slice_last_layer_and_flatten(x):
        # Keep only the final layer's prediction and flatten it to 1-D.
        return jnp.reshape(x[-1], [-1])

    # `jax.tree_util.tree_map` replaces the deprecated (and, in recent JAX,
    # removed) `jax.tree_map` alias.
    flat_pred_frames = jax.tree_util.tree_map(_slice_last_layer_and_flatten,
                                              pred_frames)
    flat_pred_positions = jax.tree_util.tree_map(_slice_last_layer_and_flatten,
                                                 pred_positions)
    # FAPE Loss on sidechains
    fape = all_atom.frame_aligned_point_error(
        pred_frames=flat_pred_frames,
        target_frames=flat_gt_frames,
        frames_mask=flat_frames_mask,
        pred_positions=flat_pred_positions,
        target_positions=flat_gt_positions,
        positions_mask=flat_positions_mask,
        l1_clamp_distance=config.sidechain.atom_clamp_distance,
        length_scale=config.sidechain.length_scale)

    return {'fape': fape, 'loss': fape}
示例#4
0
    def test_frame_aligned_point_error_perfect_on_global_transform(
            self, rot_angle, translation):
        """Applying one global rigid motion to frames and points keeps FAPE 0.

        Predictions are obtained by moving both the target frames and the
        target points with the same transform, so after frame alignment the
        error must vanish.
        """
        # pylint: disable=bad-whitespace
        coords = np.array([[21.182, 23.095, 19.731],
                           [22.055, 20.919, 17.294],
                           [24.599, 20.005, 15.041],
                           [25.567, 18.214, 12.166],
                           [28.063, 17.082, 10.043],
                           [28.779, 15.569, 6.985],
                           [30.581, 13.815, 4.612],
                           [29.258, 12.193, 2.296]])
        # pylint: enable=bad-whitespace
        transform = get_global_rigid_transform(rot_angle, translation, 1)

        # Points: target as given, prediction = transform applied to target.
        target_positions = r3.vecs_from_tensor(coords)
        pred_positions = r3.rigids_mul_vecs(transform, target_positions)
        positions_mask = np.ones(target_positions.x.shape[0])

        # Frames: same relation between prediction and target.
        num_frames = 10
        target_frames = get_identity_rigid(num_frames)
        pred_frames = r3.rigids_mul_rigids(transform, target_frames)
        frames_mask = np.ones(num_frames)

        fape = all_atom.frame_aligned_point_error(
            pred_frames, target_frames, frames_mask,
            pred_positions, target_positions, positions_mask,
            L1_CLAMP_DISTANCE, L1_CLAMP_DISTANCE,
            epsilon=0)
        self.assertAlmostEqual(fape, 0.)
示例#5
0
def atom37_to_torsion_angles(
    aatype: jnp.ndarray,  # (B, N)
    all_atom_pos: jnp.ndarray,  # (B, N, 37, 3)
    all_atom_mask: jnp.ndarray,  # (B, N, 37)
    placeholder_for_undefined: bool = False,
) -> Dict[str, jnp.ndarray]:
    """Computes the 7 torsion angles (in sin, cos encoding) for each residue.

    The 7 torsion angles are in the order
    '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
    here pre_omega denotes the omega torsion angle between the given amino acid
    and the previous amino acid.

    Residue types above 20 are clamped to 20 ('Unknown'), which has no chi
    angles. For residue index 0 along the N axis the "previous" residue is
    zero padding with an all-zero mask, so its pre_omega and phi masks are 0.

    Args:
      aatype: Amino acid type, given as array with integers.
      all_atom_pos: atom37 representation of all atom coordinates.
      all_atom_mask: atom37 representation of mask on all atom coordinates.
      placeholder_for_undefined: If True, masked torsion angles in both
        returned sin/cos arrays are replaced by the placeholder
        (sin, cos) = (1, 0), computed as
        `angles * mask + placeholder * (1 - mask)`. If False, masked angles
        keep whatever was computed from the (possibly missing) atom positions.

    Returns:
      Dict containing:
        * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the
          final dimension holds sin and cos respectively.
        * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
          with the angle shifted by pi (sin and cos negated) for all chi angles
          affected by the naming ambiguities.
        * 'torsion_angles_mask': Mask (B, N, 7) for which of the 7 torsion
          angles are present.
    """

    # Map aatype > 20 to 'Unknown' (20).
    aatype = jnp.minimum(aatype, 20)

    # Compute the backbone angles.
    num_batch, num_res = aatype.shape

    # Shift atoms and masks by one residue along N (zero-padding index 0) so
    # that entry i holds the atoms of residue i - 1.
    pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32)
    prev_all_atom_pos = jnp.concatenate([pad, all_atom_pos[:, :-1, :, :]],
                                        axis=1)

    pad = jnp.zeros([num_batch, 1, 37], jnp.float32)
    prev_all_atom_mask = jnp.concatenate([pad, all_atom_mask[:, :-1, :]],
                                         axis=1)

    # For each torsion angle collect the 4 atom positions that define this angle.
    # shape (B, N, atoms=4, xyz=3)
    pre_omega_atom_pos = jnp.concatenate(
        [
            prev_all_atom_pos[:, :, 1:3, :],  # prev CA, C
            all_atom_pos[:, :, 0:2, :]  # this N, CA
        ],
        axis=-2)
    phi_atom_pos = jnp.concatenate(
        [
            prev_all_atom_pos[:, :, 2:3, :],  # prev C
            all_atom_pos[:, :, 0:3, :]  # this N, CA, C
        ],
        axis=-2)
    psi_atom_pos = jnp.concatenate(
        [
            all_atom_pos[:, :, 0:3, :],  # this N, CA, C
            all_atom_pos[:, :, 4:5, :]  # this O
        ],
        axis=-2)

    # Collect the masks from these atoms.
    # Shape [batch, num_res]
    pre_omega_mask = (
        jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1)  # prev CA, C
        * jnp.prod(all_atom_mask[:, :, 0:2], axis=-1))  # this N, CA
    phi_mask = (
        prev_all_atom_mask[:, :, 2]  # prev C
        * jnp.prod(all_atom_mask[:, :, 0:3], axis=-1))  # this N, CA, C
    psi_mask = (
        jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) *  # this N, CA, C
        all_atom_mask[:, :, 4])  # this O

    # Collect the atoms for the chi-angles.
    # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
    chi_atom_indices = get_chi_atom_indices()
    # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
    atom_indices = utils.batched_gather(params=chi_atom_indices,
                                        indices=aatype,
                                        axis=0,
                                        batch_dims=0)
    # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
    chis_atom_pos = utils.batched_gather(params=all_atom_pos,
                                         indices=atom_indices,
                                         axis=-2,
                                         batch_dims=2)

    # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
    chi_angles_mask = list(residue_constants.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = jnp.asarray(chi_angles_mask)

    # Compute the chi angle mask. I.e. which chis angles exist according to the
    # aatype. Shape [batch, num_res, chis=4].
    chis_mask = utils.batched_gather(params=chi_angles_mask,
                                     indices=aatype,
                                     axis=0,
                                     batch_dims=0)

    # Constrain the chis_mask to those chis, where the ground truth coordinates of
    # all defining four atoms are available.
    # Gather the chi angle atoms mask. Shape: [batch, num_res, chis=4, atoms=4].
    chi_angle_atoms_mask = utils.batched_gather(params=all_atom_mask,
                                                indices=atom_indices,
                                                axis=-1,
                                                batch_dims=2)
    # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
    chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
    chis_mask = chis_mask * (chi_angle_atoms_mask).astype(jnp.float32)

    # Stack all torsion angle atom positions.
    # Shape (B, N, torsions=7, atoms=4, xyz=3)
    torsions_atom_pos = jnp.concatenate([
        pre_omega_atom_pos[:, :, None, :, :], phi_atom_pos[:, :, None, :, :],
        psi_atom_pos[:, :, None, :, :], chis_atom_pos
    ],
                                        axis=2)

    # Stack up masks for all torsion angles.
    # shape (B, N, torsions=7)
    torsion_angles_mask = jnp.concatenate([
        pre_omega_mask[:, :, None], phi_mask[:, :, None], psi_mask[:, :, None],
        chis_mask
    ],
                                          axis=2)

    # Create a frame from the first three atoms:
    # First atom: point on x-y-plane
    # Second atom: point on negative x-axis
    # Third atom: origin
    # r3.Rigids (B, N, torsions=7)
    torsion_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :,
                                                                  1, :]),
        origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
        point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :,
                                                                0, :]))

    # Compute the position of the fourth atom in this frame (y and z coordinate
    # define the chi angle)
    # r3.Vecs (B, N, torsions=7)
    forth_atom_rel_pos = r3.rigids_mul_vecs(
        r3.invert_rigids(torsion_frames),
        r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))

    # Normalize to have the sin and cos of the torsion angle: sin <- z,
    # cos <- y, divided by the (y, z) norm (1e-8 avoids division by zero).
    # jnp.ndarray (B, N, torsions=7, sincos=2)
    torsion_angles_sin_cos = jnp.stack(
        [forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1)
    torsion_angles_sin_cos /= jnp.sqrt(
        jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) +
        1e-8)

    # Mirror psi (torsion index 2), because we computed it from the Oxygen-atom.
    torsion_angles_sin_cos *= jnp.asarray([1., 1., -1., 1., 1., 1.,
                                           1.])[None, None, :, None]

    # Create alternative angles for ambiguous atom names: the three backbone
    # torsions are kept (factor 1), ambiguous chis get factor 1 - 2 * flag.
    chi_is_ambiguous = utils.batched_gather(
        jnp.asarray(residue_constants.chi_pi_periodic), aatype)
    mirror_torsion_angles = jnp.concatenate(
        [jnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous],
        axis=-1)
    alt_torsion_angles_sin_cos = (torsion_angles_sin_cos *
                                  mirror_torsion_angles[:, :, :, None])

    if placeholder_for_undefined:
        # Add placeholder torsions in place of undefined torsion angles
        # (e.g. N-terminus pre-omega). The placeholder is (sin, cos) = (1, 0);
        # it is blended in with weight (1 - mask).
        placeholder_torsions = jnp.stack([
            jnp.ones(torsion_angles_sin_cos.shape[:-1]),
            jnp.zeros(torsion_angles_sin_cos.shape[:-1])
        ],
                                         axis=-1)
        torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[
            ...,
            None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
        alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[
            ...,
            None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])

    return {
        'torsion_angles_sin_cos': torsion_angles_sin_cos,  # (B, N, 7, 2)
        'alt_torsion_angles_sin_cos':
        alt_torsion_angles_sin_cos,  # (B, N, 7, 2)
        'torsion_angles_mask': torsion_angles_mask  # (B, N, 7)
    }
示例#6
0
def atom37_to_frames(
        aatype: jnp.ndarray,  # (...)
        all_atom_positions: jnp.ndarray,  # (..., 37, 3)
        all_atom_mask: jnp.ndarray,  # (..., 37)
) -> Dict[str, jnp.ndarray]:
    """Computes the frames for the up to 8 rigid groups for each residue.

    The rigid groups are defined by the possible torsions in a given amino acid.
    We group the atoms according to their dependence on the torsion angles into
    "rigid groups".  E.g., the position of atoms in the chi2-group depend on
    chi1 and chi2, but do not depend on chi3 or chi4.
    Jumper et al. (2021) Suppl. Table 2 and corresponding text.

    Group indices and the atoms their frames are built from:
      0: backbone group (C, CA, N),
      1: pre-omega group (empty),
      2: phi group (currently empty, because it defines only hydrogens),
      3: psi group (CA, C, O),
      4-7: chi1-chi4 groups (last three atoms defining the chi angle).

    Args:
      aatype: Amino acid type, given as array with integers.
      all_atom_positions: atom37 representation of all atom coordinates.
      all_atom_mask: atom37 representation of mask on all atom coordinates.

    Returns:
      Dictionary containing (leading dimensions equal `aatype.shape`):
        * 'rigidgroups_gt_frames': 8 Frames corresponding to
          'all_atom_positions' represented as flat 12 dimensional array.
        * 'rigidgroups_gt_exists': Mask denoting whether the atom positions
          for the given frame are available in the ground truth, e.g. if they
          were resolved in the experiment.
        * 'rigidgroups_group_exists': Mask denoting whether given group is in
          principle present for given amino acid type.
        * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
          affected by naming ambiguity.
        * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming
          corresponding to 'all_atom_positions' represented as flat
          12 dimensional array.
    """
    aatype_in_shape = aatype.shape

    # If there is a batch axis, just flatten it away, and reshape everything
    # back at the end of the function.
    aatype = jnp.reshape(aatype, [-1])
    all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3])
    all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])

    # Create an array with the atom names.
    # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)

    # 0: backbone frame
    restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']

    # 3: 'psi-group'
    restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # 4,5,6,7: 'chi1,2,3,4-group'. A chi group's frame is built from the last
    # three of the four atoms that define that chi angle.
    for restype, restype_letter in enumerate(residue_constants.restypes):
        resname = residue_constants.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if residue_constants.chi_angles_mask[restype][chi_idx]:
                atom_names = residue_constants.chi_angles_atoms[resname][
                    chi_idx]
                restype_rigidgroup_base_atom_names[restype, chi_idx +
                                                   4, :] = atom_names[1:]

    # Create mask for existing rigid groups. Groups 1 and 2 are never present;
    # row 20 only gets the backbone and psi groups.
    restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_mask[:, 0] = 1
    restype_rigidgroup_mask[:, 3] = 1
    restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask

    # Translate atom names into atom37 indices. Empty slots ('') map to index
    # 0; they belong to groups that `restype_rigidgroup_mask` marks as absent.
    lookuptable = residue_constants.atom_order.copy()
    lookuptable[''] = 0
    restype_rigidgroup_base_atom37_idx = np.vectorize(
        lambda x: lookuptable[x])(restype_rigidgroup_base_atom_names)

    # Compute the gather indices for all residues in the chain.
    # shape (N, 8, 3)
    residx_rigidgroup_base_atom37_idx = utils.batched_gather(
        restype_rigidgroup_base_atom37_idx, aatype)

    # Gather the base atom positions for each rigid group.
    base_atom_pos = utils.batched_gather(all_atom_positions,
                                         residx_rigidgroup_base_atom37_idx,
                                         batch_dims=1)

    # Compute the Rigids.
    gt_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
        origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
        point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :]))

    # Compute a mask whether the group exists.
    # (N, 8)
    group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype)

    # Compute a mask whether ground truth exists for the group: the minimum
    # over the three base atoms, i.e. all of them must be present.
    gt_atoms_exist = utils.batched_gather(  # shape (N, 8, 3)
        all_atom_mask.astype(jnp.float32),
        residx_rigidgroup_base_atom37_idx,
        batch_dims=1)
    gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists  # (N, 8)

    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
    # Only group 0 gets diag(-1, 1, -1); all other groups get the identity.
    rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
    rots[0, 0, 0] = -1
    rots[0, 2, 2] = -1
    gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots))

    # The frames for ambiguous rigid groups are just rotated by 180 degree around
    # the x-axis, i.e. diag(1, -1, -1). The ambiguous group is always the last
    # chi-group of the residue type.
    restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32),
                                      [21, 8, 1, 1])

    # Only the residue names are needed, not the swap tables themselves.
    for resname in residue_constants.residue_atom_renaming_swaps:
        restype = residue_constants.restype_order[
            residue_constants.restype_3to1[resname]]
        chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1

    # Gather the ambiguity information for each residue.
    residx_rigidgroup_is_ambiguous = utils.batched_gather(
        restype_rigidgroup_is_ambiguous, aatype)
    residx_rigidgroup_ambiguity_rot = utils.batched_gather(
        restype_rigidgroup_rots, aatype)

    # Create the alternative ground truth frames.
    alt_gt_frames = r3.rigids_mul_rots(
        gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot))

    gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames)
    alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames)

    # Reshape back to the original residue layout (each output exactly once).
    frames_shape = aatype_in_shape + (8, 12)
    mask_shape = aatype_in_shape + (8,)
    gt_frames_flat12 = jnp.reshape(gt_frames_flat12, frames_shape)
    alt_gt_frames_flat12 = jnp.reshape(alt_gt_frames_flat12, frames_shape)
    gt_exists = jnp.reshape(gt_exists, mask_shape)
    group_exists = jnp.reshape(group_exists, mask_shape)
    residx_rigidgroup_is_ambiguous = jnp.reshape(
        residx_rigidgroup_is_ambiguous, mask_shape)

    return {
        'rigidgroups_gt_frames': gt_frames_flat12,  # (..., 8, 12)
        'rigidgroups_gt_exists': gt_exists,  # (..., 8)
        'rigidgroups_group_exists': group_exists,  # (..., 8)
        'rigidgroups_group_is_ambiguous':
        residx_rigidgroup_is_ambiguous,  # (..., 8)
        'rigidgroups_alt_gt_frames': alt_gt_frames_flat12,  # (..., 8, 12)
    }