def test_frame_aligned_point_error_matches_expected(
        self, target_positions, pred_positions, expected_alddt):
    """Checks the frame aligned point error against a reference value.

    Prediction and target share the same two frames, so any non-zero score
    comes from the difference between the supplied positions.
    """
    # The very same frames object is used for prediction and target, and
    # every frame is marked valid.
    num_frames = 2
    shared_frames = get_identity_rigid(num_frames)
    all_frames_valid = np.ones(num_frames)

    target_vecs = r3.vecs_from_tensor(np.array(target_positions))
    pred_vecs = r3.vecs_from_tensor(np.array(pred_positions))
    # One mask entry per target point; all points contribute.
    all_positions_valid = np.ones(target_vecs.x.shape[0])

    fape = all_atom.frame_aligned_point_error(
        shared_frames,        # pred_frames
        shared_frames,        # target_frames
        all_frames_valid,
        pred_vecs,
        target_vecs,
        all_positions_valid,
        L1_CLAMP_DISTANCE,
        L1_CLAMP_DISTANCE,
        epsilon=0)

    self.assertAlmostEqual(fape, expected_alddt)
def frames_and_literature_positions_to_atom14_pos(
        aatype: jnp.ndarray,  # (N)
        all_frames_to_global: r3.Rigids  # (N, 8)
) -> r3.Vecs:  # (N, 14)
    """Put atom literature positions (atom14 encoding) in each rigid group.

    Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11

    Args:
      aatype: aatype for each residue.
      all_frames_to_global: All per residue coordinate frames.

    Returns:
      Positions of all atom coordinates in global frame.
    """
    # Pick the appropriate transform for every atom.
    residx_to_group_idx = utils.batched_gather(
        residue_constants.restype_atom14_to_rigid_group, aatype)
    group_mask = jax.nn.one_hot(
        residx_to_group_idx, num_classes=8)  # shape (N, 14, 8)

    # Select, for every atom, the frame of the rigid group it belongs to by
    # summing the per-group frame components against the one-hot group mask.
    # `jax.tree_util.tree_map` is used instead of the top-level alias
    # `jax.tree_map`, which is deprecated and removed in recent JAX releases.
    # r3.Rigids with shape (N, 14)
    map_atoms_to_global = jax.tree_util.tree_map(
        lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1),
        all_frames_to_global)

    # Gather the literature atom positions for each residue.
    # r3.Vecs with shape (N, 14)
    lit_positions = r3.vecs_from_tensor(
        utils.batched_gather(
            residue_constants.restype_atom14_rigid_group_positions, aatype))

    # Transform each atom from its local frame to the global frame.
    # r3.Vecs with shape (N, 14)
    pred_positions = r3.rigids_mul_vecs(map_atoms_to_global, lit_positions)

    # Mask out non-existing atoms.
    mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype)
    pred_positions = jax.tree_util.tree_map(lambda x: x * mask, pred_positions)

    return pred_positions
def sidechain_loss(batch, value, config):
    """All Atom FAPE Loss using renamed rigids.

    Args:
      batch: Mapping providing 'rigidgroups_gt_frames',
        'rigidgroups_alt_gt_frames' and 'rigidgroups_gt_exists'.
      value: Mapping providing 'alt_naming_is_better',
        'renamed_atom14_gt_positions', 'renamed_atom14_gt_exists' and
        'sidechains' (itself holding 'frames' and 'atom_pos').
      config: Config object; `config.sidechain.atom_clamp_distance` and
        `config.sidechain.length_scale` are read.

    Returns:
      Dict with keys 'fape' and 'loss', both holding the same FAPE value.
    """
    # Rename Frames
    # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms"
    # line 7
    alt_naming_is_better = value['alt_naming_is_better']
    # Per-residue blend between the original and the alternative ground truth
    # frames; the selector is broadcast over the two trailing axes.
    renamed_gt_frames = (
        (1. - alt_naming_is_better[:, None, None])
        * batch['rigidgroups_gt_frames']
        + alt_naming_is_better[:, None, None]
        * batch['rigidgroups_alt_gt_frames'])

    # Flatten everything to a single leading axis for the FAPE computation.
    flat_gt_frames = r3.rigids_from_tensor_flat12(
        jnp.reshape(renamed_gt_frames, [-1, 12]))
    flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists'], [-1])
    flat_gt_positions = r3.vecs_from_tensor(
        jnp.reshape(value['renamed_atom14_gt_positions'], [-1, 3]))
    flat_positions_mask = jnp.reshape(value['renamed_atom14_gt_exists'], [-1])

    # Compute frame_aligned_point_error score for the final layer.
    pred_frames = value['sidechains']['frames']
    pred_positions = value['sidechains']['atom_pos']

    def _slice_last_layer_and_flatten(x):
        # Index -1 on the leading axis picks the last entry, then flatten.
        return jnp.reshape(x[-1], [-1])

    # `jax.tree_util.tree_map` replaces the top-level alias `jax.tree_map`,
    # which is deprecated and removed in recent JAX releases.
    flat_pred_frames = jax.tree_util.tree_map(
        _slice_last_layer_and_flatten, pred_frames)
    flat_pred_positions = jax.tree_util.tree_map(
        _slice_last_layer_and_flatten, pred_positions)

    # FAPE Loss on sidechains
    fape = all_atom.frame_aligned_point_error(
        pred_frames=flat_pred_frames,
        target_frames=flat_gt_frames,
        frames_mask=flat_frames_mask,
        pred_positions=flat_pred_positions,
        target_positions=flat_gt_positions,
        positions_mask=flat_positions_mask,
        l1_clamp_distance=config.sidechain.atom_clamp_distance,
        length_scale=config.sidechain.length_scale)

    return {'fape': fape, 'loss': fape}
def test_frame_aligned_point_error_perfect_on_global_transform(
        self, rot_angle, translation):
    """Checks that a shared global transform yields a zero error.

    The same rigid transform is applied to both the target positions and the
    target frames to build the prediction, so the score is expected to be 0.
    """
    num_frames = 10

    # pylint: disable=bad-whitespace
    raw_coords = np.array([
        [21.182, 23.095, 19.731],
        [22.055, 20.919, 17.294],
        [24.599, 20.005, 15.041],
        [25.567, 18.214, 12.166],
        [28.063, 17.082, 10.043],
        [28.779, 15.569, 6.985],
        [30.581, 13.815, 4.612],
        [29.258, 12.193, 2.296],
    ])
    # pylint: enable=bad-whitespace

    transform = get_global_rigid_transform(rot_angle, translation, 1)

    # Positions: the prediction is the transformed target.
    target_vecs = r3.vecs_from_tensor(raw_coords)
    pred_vecs = r3.rigids_mul_vecs(transform, target_vecs)
    all_positions_valid = np.ones(target_vecs.x.shape[0])

    # Frames: the prediction is the transformed target as well.
    target_rigids = get_identity_rigid(num_frames)
    pred_rigids = r3.rigids_mul_rigids(transform, target_rigids)
    all_frames_valid = np.ones(num_frames)

    fape = all_atom.frame_aligned_point_error(
        pred_rigids,
        target_rigids,
        all_frames_valid,
        pred_vecs,
        target_vecs,
        all_positions_valid,
        L1_CLAMP_DISTANCE,
        L1_CLAMP_DISTANCE,
        epsilon=0)

    self.assertAlmostEqual(fape, 0.)
def atom37_to_torsion_angles(
        aatype: jnp.ndarray,  # (B, N)
        all_atom_pos: jnp.ndarray,  # (B, N, 37, 3)
        all_atom_mask: jnp.ndarray,  # (B, N, 37)
        placeholder_for_undefined=False,
) -> Dict[str, jnp.ndarray]:
    """Computes the 7 torsion angles (in sin, cos encoding) for each residue.

    The 7 torsion angles are in the order
    '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
    here pre_omega denotes the omega torsion angle between the given amino
    acid and the previous amino acid.

    Args:
      aatype: Amino acid type, given as array with integers.
      all_atom_pos: atom37 representation of all atom coordinates.
      all_atom_mask: atom37 representation of mask on all atom coordinates.
      placeholder_for_undefined: if True, torsion angles whose mask is 0 are
        replaced by the placeholder encoding (sin, cos) = (1, 0) in both
        returned angle arrays.

    Returns:
      Dict containing:
        * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the
          final 2 dimensions denote sin and cos respectively
        * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
          with the angle shifted by pi for all chi angles affected by the
          naming ambiguities.
        * 'torsion_angles_mask': Mask with shape (B, N, 7) for which of the
          7 torsion angles are present.
    """
    # Map aatype > 20 to 'Unknown' (20).
    aatype = jnp.minimum(aatype, 20)

    # Compute the backbone angles.
    num_batch, num_res = aatype.shape

    # Shift positions and masks by one residue so that index i holds the data
    # of residue i - 1; the first residue gets an all-zero (masked) entry.
    pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32)
    prev_all_atom_pos = jnp.concatenate(
        [pad, all_atom_pos[:, :-1, :, :]], axis=1)

    pad = jnp.zeros([num_batch, 1, 37], jnp.float32)
    prev_all_atom_mask = jnp.concatenate(
        [pad, all_atom_mask[:, :-1, :]], axis=1)

    # For each torsion angle collect the 4 atom positions that define this
    # angle. shape (B, N, atoms=4, xyz=3)
    pre_omega_atom_pos = jnp.concatenate(
        [
            prev_all_atom_pos[:, :, 1:3, :],  # prev CA, C
            all_atom_pos[:, :, 0:2, :]  # this N, CA
        ],
        axis=-2)
    phi_atom_pos = jnp.concatenate(
        [
            prev_all_atom_pos[:, :, 2:3, :],  # prev C
            all_atom_pos[:, :, 0:3, :]  # this N, CA, C
        ],
        axis=-2)
    psi_atom_pos = jnp.concatenate(
        [
            all_atom_pos[:, :, 0:3, :],  # this N, CA, C
            all_atom_pos[:, :, 4:5, :]  # this O
        ],
        axis=-2)

    # Collect the masks from these atoms.
    # Shape [batch, num_res]
    pre_omega_mask = (
        jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1)  # prev CA, C
        * jnp.prod(all_atom_mask[:, :, 0:2], axis=-1))  # this N, CA
    phi_mask = (
        prev_all_atom_mask[:, :, 2]  # prev C
        * jnp.prod(all_atom_mask[:, :, 0:3], axis=-1))  # this N, CA, C
    psi_mask = (
        jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) *  # this N, CA, C
        all_atom_mask[:, :, 4])  # this O

    # Collect the atoms for the chi-angles.
    # Compute the table of chi angle indices.
    # Shape: [restypes, chis=4, atoms=4].
    chi_atom_indices = get_chi_atom_indices()
    # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
    atom_indices = utils.batched_gather(
        params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0)
    # Gather atom positions.
    # Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
    chis_atom_pos = utils.batched_gather(
        params=all_atom_pos, indices=atom_indices, axis=-2, batch_dims=2)

    # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
    chi_angles_mask = list(residue_constants.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = jnp.asarray(chi_angles_mask)

    # Compute the chi angle mask. I.e. which chis angles exist according to
    # the aatype. Shape [batch, num_res, chis=4].
    chis_mask = utils.batched_gather(
        params=chi_angles_mask, indices=aatype, axis=0, batch_dims=0)

    # Constrain the chis_mask to those chis, where the ground truth
    # coordinates of all defining four atoms are available.
    # Gather the chi angle atoms mask.
    # Shape: [batch, num_res, chis=4, atoms=4].
    chi_angle_atoms_mask = utils.batched_gather(
        params=all_atom_mask, indices=atom_indices, axis=-1, batch_dims=2)
    # Check if all 4 chi angle atoms were set.
    # Shape: [batch, num_res, chis=4].
    # An integer axis is used (rather than a list) to match the other
    # reductions in this function and to keep the argument hashable.
    chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=-1)
    chis_mask = chis_mask * (chi_angle_atoms_mask).astype(jnp.float32)

    # Stack all torsion angle atom positions.
    # Shape (B, N, torsions=7, atoms=4, xyz=3)
    torsions_atom_pos = jnp.concatenate(
        [
            pre_omega_atom_pos[:, :, None, :, :],
            phi_atom_pos[:, :, None, :, :],
            psi_atom_pos[:, :, None, :, :],
            chis_atom_pos
        ],
        axis=2)

    # Stack up masks for all torsion angles.
    # shape (B, N, torsions=7)
    torsion_angles_mask = jnp.concatenate(
        [
            pre_omega_mask[:, :, None],
            phi_mask[:, :, None],
            psi_mask[:, :, None],
            chis_mask
        ],
        axis=2)

    # Create a frame from the first three atoms:
    # First atom: point on x-y-plane
    # Second atom: point on negative x-axis
    # Third atom: origin
    # r3.Rigids (B, N, torsions=7)
    torsion_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(
            torsions_atom_pos[:, :, :, 1, :]),
        origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
        point_on_xy_plane=r3.vecs_from_tensor(
            torsions_atom_pos[:, :, :, 0, :]))

    # Compute the position of the fourth atom in this frame (y and z
    # coordinate define the chi angle)
    # r3.Vecs (B, N, torsions=7)
    fourth_atom_rel_pos = r3.rigids_mul_vecs(
        r3.invert_rigids(torsion_frames),
        r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))

    # Normalize to have the sin and cos of the torsion angle.
    # jnp.ndarray (B, N, torsions=7, sincos=2)
    torsion_angles_sin_cos = jnp.stack(
        [fourth_atom_rel_pos.z, fourth_atom_rel_pos.y], axis=-1)
    # The 1e-8 keeps the division finite when the vector has zero length.
    torsion_angles_sin_cos /= jnp.sqrt(
        jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True)
        + 1e-8)

    # Mirror psi, because we computed it from the Oxygen-atom.
    torsion_angles_sin_cos *= jnp.asarray(
        [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]

    # Create alternative angles for ambiguous atom names.
    chi_is_ambiguous = utils.batched_gather(
        jnp.asarray(residue_constants.chi_pi_periodic), aatype)
    # Factor is +1 for the three backbone torsions and for unambiguous chis,
    # and -1 for ambiguous chis (negating sin and cos shifts the angle by pi).
    mirror_torsion_angles = jnp.concatenate(
        [jnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous],
        axis=-1)
    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])

    if placeholder_for_undefined:
        # Add placeholder torsions in place of undefined torsion angles
        # (e.g. N-terminus pre-omega). The placeholder is (sin, cos) = (1, 0).
        placeholder_torsions = jnp.stack(
            [
                jnp.ones(torsion_angles_sin_cos.shape[:-1]),
                jnp.zeros(torsion_angles_sin_cos.shape[:-1])
            ],
            axis=-1)
        torsion_angles_sin_cos = (
            torsion_angles_sin_cos * torsion_angles_mask[..., None]
            + placeholder_torsions * (1 - torsion_angles_mask[..., None]))
        alt_torsion_angles_sin_cos = (
            alt_torsion_angles_sin_cos * torsion_angles_mask[..., None]
            + placeholder_torsions * (1 - torsion_angles_mask[..., None]))

    return {
        'torsion_angles_sin_cos': torsion_angles_sin_cos,  # (B, N, 7, 2)
        'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos,  # (B, N, 7, 2)
        'torsion_angles_mask': torsion_angles_mask  # (B, N, 7)
    }
def atom37_to_frames(
        aatype: jnp.ndarray,  # (...)
        all_atom_positions: jnp.ndarray,  # (..., 37, 3)
        all_atom_mask: jnp.ndarray,  # (..., 37)
) -> Dict[str, jnp.ndarray]:
    """Computes the frames for the up to 8 rigid groups for each residue.

    The rigid groups are defined by the possible torsions in a given amino
    acid. We group the atoms according to their dependence on the torsion
    angles into "rigid groups". E.g., the position of atoms in the chi2-group
    depend on chi1 and chi2, but do not depend on chi3 or chi4.
    Jumper et al. (2021) Suppl. Table 2 and corresponding text.

    Args:
      aatype: Amino acid type, given as array with integers.
      all_atom_positions: atom37 representation of all atom coordinates.
      all_atom_mask: atom37 representation of mask on all atom coordinates.

    Returns:
      Dictionary containing:
        * 'rigidgroups_gt_frames': 8 Frames corresponding to
          'all_atom_positions' represented as flat 12 dimensional array.
        * 'rigidgroups_gt_exists': Mask denoting whether the atom positions
          for the given frame are available in the ground truth, e.g. if they
          were resolved in the experiment.
        * 'rigidgroups_group_exists': Mask denoting whether given group is in
          principle present for given amino acid type.
        * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
          affected by naming ambiguity.
        * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming
          corresponding to 'all_atom_positions' represented as flat
          12 dimensional array.
    """
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'
    aatype_in_shape = aatype.shape

    # If there is a batch axis, just flatten it away, and reshape everything
    # back at the end of the function.
    aatype = jnp.reshape(aatype, [-1])
    all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3])
    all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])

    # Create an array with the atom names.
    # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)

    # 0: backbone frame
    restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']

    # 3: 'psi-group'
    restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # 4,5,6,7: 'chi1,2,3,4-group'
    for restype, restype_letter in enumerate(residue_constants.restypes):
        resname = residue_constants.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if residue_constants.chi_angles_mask[restype][chi_idx]:
                atom_names = residue_constants.chi_angles_atoms[resname][
                    chi_idx]
                # The last three of the four chi-defining atoms form the base
                # of the corresponding rigid group.
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :] = atom_names[1:]

    # Create mask for existing rigid groups.
    restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_mask[:, 0] = 1
    restype_rigidgroup_mask[:, 3] = 1
    restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask

    # Translate atom names into atom37 indices. The empty name (used for
    # groups without base atoms) is mapped to index 0 so the gather below is
    # always valid; such groups are zeroed out via the group mask.
    lookuptable = residue_constants.atom_order.copy()
    lookuptable[''] = 0
    restype_rigidgroup_base_atom37_idx = np.vectorize(
        lambda x: lookuptable[x])(restype_rigidgroup_base_atom_names)

    # Compute the gather indices for all residues in the chain.
    # shape (N, 8, 3)
    residx_rigidgroup_base_atom37_idx = utils.batched_gather(
        restype_rigidgroup_base_atom37_idx, aatype)

    # Gather the base atom positions for each rigid group.
    base_atom_pos = utils.batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        batch_dims=1)

    # Compute the Rigids.
    gt_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
        origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
        point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :]))

    # Compute a mask whether the group exists.
    # (N, 8)
    group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype)

    # Compute a mask whether ground truth exists for the group
    gt_atoms_exist = utils.batched_gather(  # shape (N, 8, 3)
        all_atom_mask.astype(jnp.float32),
        residx_rigidgroup_base_atom37_idx,
        batch_dims=1)
    # A frame exists only if all three base atoms exist and the group itself
    # is defined for the residue type.
    gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists  # (N, 8)

    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
    rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
    rots[0, 0, 0] = -1
    rots[0, 2, 2] = -1
    gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots))

    # The frames for ambiguous rigid groups are just rotated by 180 degree
    # around the x-axis. The ambiguous group is always the last chi-group.
    restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_rots = np.tile(
        np.eye(3, dtype=np.float32), [21, 8, 1, 1])

    # Only the residue names (the mapping's keys) are needed here.
    for resname in residue_constants.residue_atom_renaming_swaps:
        restype = residue_constants.restype_order[
            residue_constants.restype_3to1[resname]]
        chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1

    # Gather the ambiguity information for each residue.
    residx_rigidgroup_is_ambiguous = utils.batched_gather(
        restype_rigidgroup_is_ambiguous, aatype)
    residx_rigidgroup_ambiguity_rot = utils.batched_gather(
        restype_rigidgroup_rots, aatype)

    # Create the alternative ground truth frames.
    alt_gt_frames = r3.rigids_mul_rots(
        gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot))

    gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames)
    alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames)

    # reshape back to original residue layout
    gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12))
    gt_exists = jnp.reshape(gt_exists, aatype_in_shape + (8,))
    group_exists = jnp.reshape(group_exists, aatype_in_shape + (8,))
    residx_rigidgroup_is_ambiguous = jnp.reshape(
        residx_rigidgroup_is_ambiguous, aatype_in_shape + (8,))
    alt_gt_frames_flat12 = jnp.reshape(
        alt_gt_frames_flat12, aatype_in_shape + (8, 12))

    return {
        'rigidgroups_gt_frames': gt_frames_flat12,  # (..., 8, 12)
        'rigidgroups_gt_exists': gt_exists,  # (..., 8)
        'rigidgroups_group_exists': group_exists,  # (..., 8)
        'rigidgroups_group_is_ambiguous':
            residx_rigidgroup_is_ambiguous,  # (..., 8)
        'rigidgroups_alt_gt_frames': alt_gt_frames_flat12,  # (..., 8, 12)
    }