def get_alt_atom14(aatype, positions, mask):
    """Get alternative atom14 positions.

    Constructs renamed atom positions for ambiguous residues.

    Jumper et al. (2021) Suppl. Table 3 "Ambiguous atom names due to 180 degree-
    rotation-symmetry"

    Args:
        aatype: Amino acid at given position
        positions: Atom positions as r3.Vecs in atom14 representation, (N, 14)
        mask: Atom masks in atom14 representation, (N, 14)

    Returns:
        renamed atom positions, renamed atom mask
    """
    # Pick the transformation matrices for the given residue sequence.
    # shape (num_res, 14, 14)
    renaming_transform = utils.batched_gather(
        jnp.asarray(RENAMING_MATRICES), aatype)

    # Apply the per-residue renaming matrix to every coordinate component:
    #   alt[n, j] = sum_i pos[n, i] * transform[n, i, j]
    # The multiplication has to happen inside tree_map so that it acts on each
    # (N, 14) leaf array of the r3.Vecs container. Multiplying the container
    # itself does not apply the matrix per component.
    alternative_positions = jax.tree_map(
        lambda x: jnp.sum(x[:, :, None] * renaming_transform, axis=1),
        positions)

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position)
    alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1)

    return alternative_positions, alternative_mask
def frames_and_literature_positions_to_atom14_pos(
    aatype: jnp.ndarray,  # (N)
    all_frames_to_global: r3.Rigids  # (N, 8)
) -> r3.Vecs:  # (N, 14)
    """Place each atom's literature position into its rigid group's frame.

    Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11

    Args:
        aatype: aatype for each residue.
        all_frames_to_global: All per residue coordinate frames, (N, 8).

    Returns:
        Positions of all atom coordinates in global frame, r3.Vecs (N, 14).
        Slots for atoms that do not exist for the residue type are zeroed.
    """
    # For each atom14 slot, which of the 8 rigid groups it belongs to,
    # expanded to a one-hot selector of shape (N, 14, 8).
    atom_group = utils.batched_gather(
        residue_constants.restype_atom14_to_rigid_group, aatype)
    group_selector = jax.nn.one_hot(atom_group, num_classes=8)

    def _select_group_frame(frame_component):
        # (N, 8) -> (N, 1, 8), weight by the selector, reduce the group axis
        # to obtain one frame component per atom: (N, 14).
        return jnp.sum(frame_component[:, None, :] * group_selector, axis=-1)

    # r3.Rigids with shape (N, 14): the frame each atom lives in.
    atom_frames = jax.tree_map(_select_group_frame, all_frames_to_global)

    # Literature coordinates of each atom inside its own rigid group, (N, 14).
    local_coords = r3.vecs_from_tensor(
        utils.batched_gather(
            residue_constants.restype_atom14_rigid_group_positions, aatype))

    # Map local coordinates to the global frame.
    global_coords = r3.rigids_mul_vecs(atom_frames, local_coords)

    # Zero out atoms that do not exist for this residue type.
    atom_exists = utils.batched_gather(
        residue_constants.restype_atom14_mask, aatype)
    return jax.tree_map(lambda component: component * atom_exists,
                        global_coords)
def atom37_to_atom14(
    atom37_data: jnp.ndarray,  # (N, 37, ...)
    batch: Dict[str, jnp.ndarray]) -> jnp.ndarray:  # (N, 14, ...)
    """Convert atom37 to atom14 representation.

    For every residue, gathers the atom37 entries referenced by
    batch['residx_atom14_to_atom37'] along the atom axis. Entries whose
    batch['atom14_atom_exists'] flag is 0 are then zeroed out.

    Args:
        atom37_data: Per-atom data in atom37 layout. Must be rank 2 (N, 37)
            or rank 3 (N, 37, C).
        batch: Must contain 'residx_atom14_to_atom37' (gather indices into
            the atom37 axis) and 'atom14_atom_exists' (atom14 mask).

    Returns:
        The data in atom14 layout, (N, 14) or (N, 14, C), with non-existing
        atoms set to zero.
    """
    assert len(atom37_data.shape) in [2, 3]
    assert 'residx_atom14_to_atom37' in batch
    assert 'atom14_atom_exists' in batch

    atom14_data = utils.batched_gather(atom37_data,
                                       batch['residx_atom14_to_atom37'],
                                       batch_dims=1)

    atom14_mask = batch['atom14_atom_exists'].astype(atom14_data.dtype)
    if len(atom37_data.shape) == 3:
        # Broadcast the (N, 14) mask over the trailing feature axis.
        atom14_mask = atom14_mask[:, :, None]
    atom14_data *= atom14_mask
    return atom14_data
def torsion_angles_to_frames(
    aatype: jnp.ndarray,  # (N)
    backb_to_global: r3.Rigids,  # (N)
    torsion_angles_sin_cos: jnp.ndarray  # (N, 7, 2)
) -> r3.Rigids:  # (N, 8)
    """Build the 8 per-residue rigid-group frames from torsion angles.

    Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" lines 2-10
    Jumper et al. (2021) Suppl. Alg. 25 "makeRotX"

    Args:
        aatype: aatype for each residue, shape (N,).
        backb_to_global: Rigid transformations describing transformation from
            backbone frame to global frame, shape (N,).
        torsion_angles_sin_cos: sin and cosine of the 7 torsion angles,
            shape (N, 7, 2).

    Returns:
        Frames corresponding to all the Sidechain Rigid Transforms, as
        r3.Rigids of shape (N, 8) mapping each group frame to global.
    """
    assert len(aatype.shape) == 1
    assert len(backb_to_global.rot.xx.shape) == 1
    assert len(torsion_angles_sin_cos.shape) == 3
    assert torsion_angles_sin_cos.shape[1:] == (7, 2)

    # Default (literature) frame of each rigid group, r3.Rigids (N, 8).
    default_frames = r3.rigids_from_tensor4x4(
        utils.batched_gather(
            residue_constants.restype_rigid_group_default_frame, aatype))

    # Group 0 (backbone) gets the identity rotation (sin=0, cos=1). The other
    # seven groups use the supplied torsion angles.
    (num_res,) = aatype.shape
    sin_a = jnp.concatenate(
        [jnp.zeros([num_res, 1]), torsion_angles_sin_cos[..., 0]], axis=-1)
    cos_a = jnp.concatenate(
        [jnp.ones([num_res, 1]), torsion_angles_sin_cos[..., 1]], axis=-1)

    # Rotation around the x-axis for every group (makeRotX), r3.Rots (N, 8).
    zero = jnp.zeros_like(sin_a)
    one = jnp.ones_like(sin_a)
    x_rotations = r3.Rots(one, zero, zero,
                          zero, cos_a, -sin_a,
                          zero, sin_a, cos_a)
    group_to_parent = r3.rigids_mul_rots(default_frames, x_rotations)

    # chi1 (group 4) already maps to the backbone frame. chi2..chi4 (groups
    # 5..7) map to the preceding chi frame, so compose them in order.
    chained_to_backbone = [jax.tree_map(lambda x: x[:, 4], group_to_parent)]
    for group_idx in (5, 6, 7):
        local_frame = jax.tree_map(lambda x, i=group_idx: x[:, i],
                                   group_to_parent)
        chained_to_backbone.append(
            r3.rigids_mul_rigids(chained_to_backbone[-1], local_frame))

    def _splice(base, chi2, chi3, chi4):
        # Keep groups 0..4 untouched; substitute the chained chi2..chi4 frames.
        return jnp.concatenate(
            [base[:, 0:5], chi2[:, None], chi3[:, None], chi4[:, None]],
            axis=-1)

    # r3.Rigids (N, 8): every group frame expressed relative to the backbone.
    frames_to_backbone = jax.tree_map(_splice, group_to_parent,
                                      *chained_to_backbone[1:])

    # Broadcast the backbone transform over the group axis and go global.
    backbone_per_group = jax.tree_map(lambda x: x[:, None], backb_to_global)
    return r3.rigids_mul_rigids(backbone_per_group, frames_to_backbone)
def atom37_to_torsion_angles(
    aatype: jnp.ndarray,  # (B, N)
    all_atom_pos: jnp.ndarray,  # (B, N, 37, 3)
    all_atom_mask: jnp.ndarray,  # (B, N, 37)
    placeholder_for_undefined=False,
) -> Dict[str, jnp.ndarray]:
    """Computes the 7 torsion angles (in sin, cos encoding) for each residue.

    The 7 torsion angles are in the order
    '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
    here pre_omega denotes the omega torsion angle between the given amino acid
    and the previous amino acid.

    Args:
        aatype: Amino acid type, given as array with integers.
        all_atom_pos: atom37 representation of all atom coordinates.
        all_atom_mask: atom37 representation of mask on all atom coordinates.
        placeholder_for_undefined: if True, masked-out torsion angles are
            replaced by the placeholder encoding (sin=1, cos=0).

    Returns:
        Dict containing:
            * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the
              final 2 dimensions denote sin and cos respectively
            * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos',
              but with the angle shifted by pi for all chi angles affected by
              the naming ambiguities.
            * 'torsion_angles_mask': Mask for which chi angles are present.
    """
    # Residue types above 20 are treated as 'Unknown' (20).
    aatype = jnp.minimum(aatype, 20)

    num_batch, num_res = aatype.shape

    # Shift by one residue so that index i holds residue i-1; the first
    # residue gets an all-zero predecessor.
    prev_pos = jnp.concatenate(
        [jnp.zeros([num_batch, 1, 37, 3], jnp.float32),
         all_atom_pos[:, :-1, :, :]],
        axis=1)
    prev_mask = jnp.concatenate(
        [jnp.zeros([num_batch, 1, 37], jnp.float32),
         all_atom_mask[:, :-1, :]],
        axis=1)

    # Four defining atoms per backbone torsion, shape (B, N, atoms=4, xyz=3).
    # prev CA, C + this N, CA
    pre_omega_pos = jnp.concatenate(
        [prev_pos[:, :, 1:3, :], all_atom_pos[:, :, 0:2, :]], axis=-2)
    # prev C + this N, CA, C
    phi_pos = jnp.concatenate(
        [prev_pos[:, :, 2:3, :], all_atom_pos[:, :, 0:3, :]], axis=-2)
    # this N, CA, C + this O
    psi_pos = jnp.concatenate(
        [all_atom_pos[:, :, 0:3, :], all_atom_pos[:, :, 4:5, :]], axis=-2)

    # Matching masks, shape (B, N).
    this_n_ca_c_mask = jnp.prod(all_atom_mask[:, :, 0:3], axis=-1)
    pre_omega_mask = (jnp.prod(prev_mask[:, :, 1:3], axis=-1)
                      * jnp.prod(all_atom_mask[:, :, 0:2], axis=-1))
    phi_mask = prev_mask[:, :, 2] * this_n_ca_c_mask
    psi_mask = this_n_ca_c_mask * all_atom_mask[:, :, 4]

    # Per-restype chi atom indices, shape (restypes, chis=4, atoms=4),
    # gathered per residue to (B, N, chis=4, atoms=4).
    chi_atom_indices = get_chi_atom_indices()
    per_res_chi_atoms = utils.batched_gather(
        params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0)
    # Chi atom coordinates, shape (B, N, chis=4, atoms=4, xyz=3).
    chis_pos = utils.batched_gather(
        params=all_atom_pos, indices=per_res_chi_atoms, axis=-2, batch_dims=2)

    # Which chis exist per restype, with an all-zero row appended for the
    # UNKNOWN residue; shape (restypes, 4), then gathered to (B, N, 4).
    restype_chi_mask = jnp.asarray(
        list(residue_constants.chi_angles_mask) + [[0.0, 0.0, 0.0, 0.0]])
    chis_mask = utils.batched_gather(
        params=restype_chi_mask, indices=aatype, axis=0, batch_dims=0)

    # Keep only chis whose four defining atoms all have ground truth.
    chi_atoms_present = utils.batched_gather(
        params=all_atom_mask, indices=per_res_chi_atoms, axis=-1, batch_dims=2)
    chi_atoms_present = jnp.prod(chi_atoms_present, axis=-1)
    chis_mask = chis_mask * chi_atoms_present.astype(jnp.float32)

    # Stack all seven torsions: positions (B, N, 7, 4, 3), mask (B, N, 7).
    torsion_pos = jnp.concatenate(
        [pre_omega_pos[:, :, None], phi_pos[:, :, None], psi_pos[:, :, None],
         chis_pos],
        axis=2)
    torsion_angles_mask = jnp.concatenate(
        [pre_omega_mask[:, :, None], phi_mask[:, :, None],
         psi_mask[:, :, None], chis_mask],
        axis=2)

    # Frame from the first three atoms: atom 1 on the negative x-axis,
    # atom 2 at the origin, atom 0 on the x-y plane. r3.Rigids (B, N, 7).
    torsion_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(torsion_pos[:, :, :, 1, :]),
        origin=r3.vecs_from_tensor(torsion_pos[:, :, :, 2, :]),
        point_on_xy_plane=r3.vecs_from_tensor(torsion_pos[:, :, :, 0, :]))

    # The fourth atom in that frame; its (z, y) coordinates give (sin, cos).
    fourth_atom_local = r3.rigids_mul_vecs(
        r3.invert_rigids(torsion_frames),
        r3.vecs_from_tensor(torsion_pos[:, :, :, 3, :]))

    # Normalize to unit length, shape (B, N, 7, 2).
    torsion_angles_sin_cos = jnp.stack(
        [fourth_atom_local.z, fourth_atom_local.y], axis=-1)
    norm = jnp.sqrt(
        jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True)
        + 1e-8)
    torsion_angles_sin_cos /= norm

    # psi was measured via the O atom, so mirror it.
    torsion_angles_sin_cos *= jnp.asarray(
        [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]

    # Alternative angles: flip sign (shift by pi) for pi-periodic chis.
    chi_is_ambiguous = utils.batched_gather(
        jnp.asarray(residue_constants.chi_pi_periodic), aatype)
    mirror_torsion_angles = jnp.concatenate(
        [jnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous],
        axis=-1)
    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])

    if placeholder_for_undefined:
        # Replace undefined torsions (e.g. N-terminus pre-omega) by sin=1, cos=0.
        placeholder_torsions = jnp.stack([
            jnp.ones(torsion_angles_sin_cos.shape[:-1]),
            jnp.zeros(torsion_angles_sin_cos.shape[:-1])
        ], axis=-1)

        def _blend_with_placeholder(angles):
            keep = torsion_angles_mask[..., None]
            return angles * keep + placeholder_torsions * (1 - keep)

        torsion_angles_sin_cos = _blend_with_placeholder(torsion_angles_sin_cos)
        alt_torsion_angles_sin_cos = _blend_with_placeholder(
            alt_torsion_angles_sin_cos)

    return {
        'torsion_angles_sin_cos': torsion_angles_sin_cos,  # (B, N, 7, 2)
        'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos,  # (B, N, 7, 2)
        'torsion_angles_mask': torsion_angles_mask  # (B, N, 7)
    }
def _make_rigidgroup_base_atom37_idx_and_mask():
    """Build per-restype atom37 indices of the frame-defining atoms and group mask.

    Returns:
        A pair (base_atom37_idx, group_mask):
            base_atom37_idx: array (21, 8, 3) holding the atom37 indices of the
                three atoms that define each rigid group. Undefined entries map
                to index 0.
            group_mask: float32 array (21, 8), 1 where the rigid group exists
                for the residue type.
    """
    # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
    base_atom_names = np.full([21, 8, 3], '', dtype=object)

    # 0: backbone frame
    base_atom_names[:, 0, :] = ['C', 'CA', 'N']

    # 3: 'psi-group'
    base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # 4,5,6,7: 'chi1,2,3,4-group'
    for restype, restype_letter in enumerate(residue_constants.restypes):
        resname = residue_constants.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if residue_constants.chi_angles_mask[restype][chi_idx]:
                atom_names = residue_constants.chi_angles_atoms[resname][chi_idx]
                base_atom_names[restype, chi_idx + 4, :] = atom_names[1:]

    # Create mask for existing rigid groups.
    group_mask = np.zeros([21, 8], dtype=np.float32)
    group_mask[:, 0] = 1
    group_mask[:, 3] = 1
    group_mask[:20, 4:] = residue_constants.chi_angles_mask

    # Translate atom names into atom37 indices ('' -> 0).
    lookuptable = residue_constants.atom_order.copy()
    lookuptable[''] = 0
    base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(base_atom_names)
    return base_atom37_idx, group_mask


def _make_rigidgroup_ambiguity_tables():
    """Build per-restype ambiguity flags and 180-degree x-axis rotations.

    The frame of an ambiguous rigid group is rotated by 180 degrees around
    the x-axis. The ambiguous group is always the residue's last chi group.

    Returns:
        A pair (is_ambiguous, rots): float32 arrays of shape (21, 8) and
        (21, 8, 3, 3).
    """
    is_ambiguous = np.zeros([21, 8], dtype=np.float32)
    rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])
    for resname in residue_constants.residue_atom_renaming_swaps:
        restype = residue_constants.restype_order[
            residue_constants.restype_3to1[resname]]
        chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
        is_ambiguous[restype, chi_idx + 4] = 1
        rots[restype, chi_idx + 4, 1, 1] = -1
        rots[restype, chi_idx + 4, 2, 2] = -1
    return is_ambiguous, rots


def atom37_to_frames(
    aatype: jnp.ndarray,  # (...)
    all_atom_positions: jnp.ndarray,  # (..., 37, 3)
    all_atom_mask: jnp.ndarray,  # (..., 37)
) -> Dict[str, jnp.ndarray]:
    """Computes the frames for the up to 8 rigid groups for each residue.

    The rigid groups are defined by the possible torsions in a given amino
    acid. Atoms are grouped by which torsion angles their positions depend
    on. E.g., the position of atoms in the chi2-group depend on chi1 and
    chi2, but do not depend on chi3 or chi4.
    Jumper et al. (2021) Suppl. Table 2 and corresponding text.

    Args:
        aatype: Amino acid type, given as array with integers.
        all_atom_positions: atom37 representation of all atom coordinates.
        all_atom_mask: atom37 representation of mask on all atom coordinates.

    Returns:
        Dictionary containing:
            * 'rigidgroups_gt_frames': 8 Frames corresponding to
              'all_atom_positions' represented as flat 12 dimensional array.
            * 'rigidgroups_gt_exists': Mask denoting whether the atom positions
              for the given frame are available in the ground truth, e.g. if
              they were resolved in the experiment.
            * 'rigidgroups_group_exists': Mask denoting whether given group is
              in principle present for given amino acid type.
            * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
              affected by naming ambiguity.
            * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom
              renaming corresponding to 'all_atom_positions' represented as
              flat 12 dimensional array.
    """
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'
    aatype_in_shape = aatype.shape

    # If there is a batch axis, just flatten it away, and reshape everything
    # back at the end of the function.
    aatype = jnp.reshape(aatype, [-1])
    all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3])
    all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])

    restype_rigidgroup_base_atom37_idx, restype_rigidgroup_mask = (
        _make_rigidgroup_base_atom37_idx_and_mask())

    # Compute the gather indices for all residues in the chain.
    # shape (N, 8, 3)
    residx_rigidgroup_base_atom37_idx = utils.batched_gather(
        restype_rigidgroup_base_atom37_idx, aatype)

    # Gather the base atom positions for each rigid group.
    base_atom_pos = utils.batched_gather(
        all_atom_positions, residx_rigidgroup_base_atom37_idx, batch_dims=1)

    # Compute the Rigids.
    gt_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
        origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
        point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :]))

    # Compute a mask whether the group exists.
    # (N, 8)
    group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype)

    # Compute a mask whether ground truth exists for the group
    gt_atoms_exist = utils.batched_gather(  # shape (N, 8, 3)
        all_atom_mask.astype(jnp.float32),
        residx_rigidgroup_base_atom37_idx,
        batch_dims=1)
    gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists  # (N, 8)

    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
    rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
    rots[0, 0, 0] = -1
    rots[0, 2, 2] = -1
    gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots))

    # Gather the ambiguity information for each residue.
    restype_rigidgroup_is_ambiguous, restype_rigidgroup_rots = (
        _make_rigidgroup_ambiguity_tables())
    residx_rigidgroup_is_ambiguous = utils.batched_gather(
        restype_rigidgroup_is_ambiguous, aatype)
    residx_rigidgroup_ambiguity_rot = utils.batched_gather(
        restype_rigidgroup_rots, aatype)

    # Create the alternative ground truth frames.
    alt_gt_frames = r3.rigids_mul_rots(
        gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot))

    gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames)
    alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames)

    # Reshape back to the original residue layout (each array exactly once).
    per_group_shape = aatype_in_shape + (8,)
    return {
        'rigidgroups_gt_frames':
            jnp.reshape(gt_frames_flat12, per_group_shape + (12,)),  # (..., 8, 12)
        'rigidgroups_gt_exists':
            jnp.reshape(gt_exists, per_group_shape),  # (..., 8)
        'rigidgroups_group_exists':
            jnp.reshape(group_exists, per_group_shape),  # (..., 8)
        'rigidgroups_group_is_ambiguous':
            jnp.reshape(residx_rigidgroup_is_ambiguous,
                        per_group_shape),  # (..., 8)
        'rigidgroups_alt_gt_frames':
            jnp.reshape(alt_gt_frames_flat12,
                        per_group_shape + (12,)),  # (..., 8, 12)
    }
def find_structural_violations(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    config: ml_collections.ConfigDict):
    """Computes several checks for structural violations.

    Args:
        batch: Must provide 'atom14_atom_exists', 'residue_index', 'aatype'
            and 'residx_atom14_to_atom37'.
        atom14_pred_positions: Predicted atom positions, (N, 14, 3).
        config: Reads `violation_tolerance_factor` and
            `clash_overlap_tolerance`.

    Returns:
        Nested dict with 'between_residues' (bond/angle/clash losses and
        masks), 'within_residues' (per-atom losses and violations) and
        'total_per_residue_violations_mask' (N,).
    """
    # Compute between residue backbone violations of bonds and angles.
    connection_violations = all_atom.between_residue_bond_loss(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
        residue_index=batch['residue_index'].astype(jnp.float32),
        aatype=batch['aatype'],
        tolerance_factor_soft=config.violation_tolerance_factor,
        tolerance_factor_hard=config.violation_tolerance_factor)

    # Compute the Van der Waals radius for every atom
    # (the first letter of the atom name is the element type).
    # Converted to an array before gathering, like the other constant tables
    # in this file; jnp.take-based gathers reject plain Python lists.
    atomtype_radius = jnp.asarray([
        residue_constants.van_der_waals_radius[name[0]]
        for name in residue_constants.atom_types
    ])
    # Shape: (N, 14).
    atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
        atomtype_radius, batch['residx_atom14_to_atom37'])

    # Compute the between residue clash loss.
    between_residue_clashes = all_atom.between_residue_clash_loss(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch['atom14_atom_exists'],
        atom14_atom_radius=atom14_atom_radius,
        residue_index=batch['residue_index'],
        overlap_tolerance_soft=config.clash_overlap_tolerance,
        overlap_tolerance_hard=config.clash_overlap_tolerance)

    # Compute all within-residue violations (clashes,
    # bond length and angle violations).
    restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
        overlap_tolerance=config.clash_overlap_tolerance,
        bond_length_tolerance_factor=config.violation_tolerance_factor)
    atom14_dists_lower_bound = utils.batched_gather(
        restype_atom14_bounds['lower_bound'], batch['aatype'])
    atom14_dists_upper_bound = utils.batched_gather(
        restype_atom14_bounds['upper_bound'], batch['aatype'])
    within_residue_violations = all_atom.within_residue_violations(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch['atom14_atom_exists'],
        atom14_dists_lower_bound=atom14_dists_lower_bound,
        atom14_dists_upper_bound=atom14_dists_upper_bound,
        tighten_bounds_for_loss=0.0)

    # Combine them to a single per-residue violation mask (used later for LDDT).
    per_residue_violations_mask = jnp.max(jnp.stack([
        connection_violations['per_residue_violation_mask'],
        jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1),
        jnp.max(within_residue_violations['per_atom_violations'], axis=-1)
    ]), axis=0)

    return {
        'between_residues': {
            'bonds_c_n_loss_mean':
                connection_violations['c_n_loss_mean'],  # ()
            'angles_ca_c_n_loss_mean':
                connection_violations['ca_c_n_loss_mean'],  # ()
            'angles_c_n_ca_loss_mean':
                connection_violations['c_n_ca_loss_mean'],  # ()
            'connections_per_residue_loss_sum':
                connection_violations['per_residue_loss_sum'],  # (N)
            'connections_per_residue_violation_mask':
                connection_violations['per_residue_violation_mask'],  # (N)
            'clashes_mean_loss':
                between_residue_clashes['mean_loss'],  # ()
            'clashes_per_atom_loss_sum':
                between_residue_clashes['per_atom_loss_sum'],  # (N, 14)
            'clashes_per_atom_clash_mask':
                between_residue_clashes['per_atom_clash_mask'],  # (N, 14)
        },
        'within_residues': {
            'per_atom_loss_sum':
                within_residue_violations['per_atom_loss_sum'],  # (N, 14)
            'per_atom_violations':
                within_residue_violations['per_atom_violations'],  # (N, 14)
        },
        'total_per_residue_violations_mask':
            per_residue_violations_mask,  # (N)
    }