def frames_and_literature_positions_to_atom14_pos(
    aatype: paddle.Tensor,  # (B, N)
    all_frames_to_global: r3.Rigids  # (B, N, 8)
) -> r3.Vecs:  # (B, N, 14)
    """Map idealized (literature) atom14 coordinates into the global frame.

    Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11

    Args:
        aatype: Residue type of every residue.
        all_frames_to_global: Coordinate frames of every residue (one per
            rigid group).

    Returns:
        Global-frame positions of all atoms, with non-existing atom slots
        masked out.
    """

    def _lookup_by_restype(table):
        # Expand a per-residue-type constant table to one entry per residue by
        # gathering with `aatype`; the leading None adds the batch axis.
        return utils.batched_gather(
            paddle.to_tensor(table)[None, ...], aatype, batch_dims=1)

    def _select_group(component, one_hot):
        # Multiply by the one-hot group mask and sum over the group axis, so
        # every atom keeps only the entry of the rigid group it belongs to.
        # NOTE(review): presumably `.map` feeds each component tensor of the
        # rotation / translation through this function -- confirm in r3.
        return paddle.sum(paddle.unsqueeze(component, -2) * one_hot, axis=-1)

    def _apply_mask(component, atom_mask):
        return component * atom_mask

    # Rigid group index that each atom14 slot of each residue is attached to.
    # 8 rigid groups:
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'
    atom_group_index = _lookup_by_restype(
        residue_constants.restype_atom14_to_rigid_group)

    # (B, N, 14, 8)
    atom_group_one_hot = nn.functional.one_hot(atom_group_index, num_classes=8)

    # Per-atom transform into the global frame: r3.Rigids with shape (B, N, 14)
    atom_rot = all_frames_to_global.rot.map(_select_group, atom_group_one_hot)
    atom_trans = all_frames_to_global.trans.map(
        _select_group, atom_group_one_hot)
    atom_to_global = r3.Rigids(rot=atom_rot, trans=atom_trans)

    # Literature positions of every atom in its local rigid-group frame.
    # r3.Vecs with shape (B, N, 14)
    local_positions = r3.vecs_from_tensor(
        _lookup_by_restype(
            residue_constants.restype_atom14_rigid_group_positions))

    # Local frame -> global frame. r3.Vecs with shape (B, N, 14)
    global_positions = r3.rigids_mul_vecs(atom_to_global, local_positions)

    # Mask out non-existing atoms.
    atom_exists = _lookup_by_restype(residue_constants.restype_atom14_mask)
    return global_positions.map(_apply_mask, atom_exists)
def concat_rigids(*arg):
    """Join several rigids into a single one.

    The rotation tensors are concatenated along axis -3 and the translation
    tensors along axis -2.

    Args:
        *arg: Two or more r3.Rigids; the first two must have `rot.xx` of the
            same rank.

    Returns:
        A new r3.Rigids built from the concatenated rotation and translation.
    """
    assert len(arg) > 1
    first_rank = len(arg[0].rot.xx.shape)
    second_rank = len(arg[1].rot.xx.shape)
    assert first_rank == second_rank

    rot_parts = []
    for rigid in arg:
        rot_parts.append(rigid.rot.rotation)
    joined_rot = paddle.concat(rot_parts, axis=-3)

    trans_parts = []
    for rigid in arg:
        trans_parts.append(rigid.trans.translation)
    joined_trans = paddle.concat(trans_parts, axis=-2)

    return r3.Rigids(rot=r3.Rots(joined_rot), trans=r3.Vecs(joined_trans))
def unsqueeze_rigids(rigid, axis=-1):
    """Insert a size-1 axis into both tensors backing a rigid.

    Args:
        rigid: The r3.Rigids to expand.
        axis: Position of the new axis. A negative value is shifted by 2 for
            the rotation tensor and by 1 for the translation tensor before
            being passed to `paddle.unsqueeze`; a non-negative value is used
            unchanged for both.

    Returns:
        A new r3.Rigids wrapping the unsqueezed rotation and translation.
    """
    # NOTE(review): the negative-axis shift presumably skips the trailing
    # (3, 3) / (3,) component axes of the raw tensors -- confirm in r3.
    rot_axis, trans_axis = (axis - 2, axis - 1) if axis < 0 else (axis, axis)

    expanded_rot = paddle.unsqueeze(rigid.rot.rotation, axis=rot_axis)
    expanded_trans = paddle.unsqueeze(rigid.trans.translation, axis=trans_axis)
    return r3.Rigids(rot=r3.Rots(expanded_rot), trans=r3.Vecs(expanded_trans))
def sidechain_loss(batch, value, config):
    """Compute the all-atom FAPE loss against the renamed ground truth.

    Args:
        batch: Reads 'rigidgroups_gt_frames', 'rigidgroups_alt_gt_frames' and
            'rigidgroups_gt_exists'.
        value: Reads 'alt_naming_is_better', 'renamed_atom14_gt_positions',
            'renamed_atom14_gt_exists' and, under 'sidechains', the entries
            'frames_rot', 'frames_trans' and 'atom_pos' (only the last element
            of each is used).
        config: Reads `sidechain.atom_clamp_distance` and
            `sidechain.length_scale`.

    Returns:
        Dict whose 'fape' and 'loss' entries both hold the FAPE value.
    """
    # Rename Frames
    # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7
    # Blend the two ground-truth frame sets with the per-residue weight.
    use_alt = value['alt_naming_is_better'][:, :, None, None]
    gt_frames = (
        (1. - use_alt) * batch['rigidgroups_gt_frames']
        + use_alt * batch['rigidgroups_alt_gt_frames'])

    n_batch = gt_frames.shape[0]

    def _flatten(tensor, *trailing):
        # Keep the batch axis and the given trailing axes; merge the rest.
        return paddle.reshape(tensor, [n_batch, -1, *trailing])

    # Ground-truth frames / positions and their existence masks.
    target_frames = r3.rigids_from_tensor_flat12(_flatten(gt_frames, 12))
    frames_mask = _flatten(batch['rigidgroups_gt_exists'])
    target_positions = r3.vecs_from_tensor(
        _flatten(value['renamed_atom14_gt_positions'], 3))
    positions_mask = _flatten(value['renamed_atom14_gt_exists'])

    # Compute frame_aligned_point_error score for the final layer.
    sidechains = value['sidechains']
    final_rot = sidechains['frames_rot'][-1]
    final_trans = sidechains['frames_trans'][-1]
    flat_rot = _flatten(final_rot, 3, 3)
    flat_trans = _flatten(final_trans, 3)
    pred_frames = r3.Rigids(
        rot=r3.rots_from_tensor3x3(flat_rot),
        trans=r3.vecs_from_tensor(flat_trans))

    final_atom_pos = sidechains['atom_pos'][-1]
    pred_positions = r3.vecs_from_tensor(_flatten(final_atom_pos, 3))

    # FAPE Loss on sidechains
    fape = all_atom.frame_aligned_point_error(
        pred_frames=pred_frames,
        target_frames=target_frames,
        frames_mask=frames_mask,
        pred_positions=pred_positions,
        target_positions=target_positions,
        positions_mask=positions_mask,
        l1_clamp_distance=config.sidechain.atom_clamp_distance,
        length_scale=config.sidechain.length_scale)

    result = {}
    result['fape'] = fape
    result['loss'] = fape
    return result
def slice_rigids(rigid, start, end):
    """Take the sub-range [start, end) of a rigid.

    The rotation tensor is sliced on its third-from-last axis and the
    translation tensor on its second-from-last axis.

    Args:
        rigid: An r3.Rigids whose `rot.xx` has exactly three axes.
        start: First index kept.
        end: Index one past the last element kept.

    Returns:
        A new r3.Rigids wrapping the sliced rotation and translation.
    """
    rank = len(rigid.rot.xx.shape)
    assert rank == 3

    window = slice(start, end)
    sliced_rot = rigid.rot.rotation[..., window, :, :]
    sliced_trans = rigid.trans.translation[..., window, :]
    return r3.Rigids(rot=r3.Rots(sliced_rot), trans=r3.Vecs(sliced_trans))
def _stacked_trajectory_fape(loss_fn, rigid_trajectory, gt_rigid,
                             backbone_mask):
    """Apply `loss_fn` to every step of the trajectory and stack the results."""
    # The translations of each step double as the predicted positions, and the
    # ground-truth translations as the target positions.
    return paddle.stack([
        loss_fn(
            r3.Rigids(step_rot, step_trans), gt_rigid, backbone_mask,
            step_trans, gt_rigid.trans, backbone_mask)
        for step_rot, step_trans in zip(rigid_trajectory.rot,
                                        rigid_trajectory.trans)
    ])


def backbone_loss(ret, batch, value, config):
    """Backbone FAPE Loss.

    Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17

    Args:
        ret: Dictionary to write outputs into, needs to contain 'loss'.
            'fape' and 'backbone_fape' are set and 'loss' is incremented.
        batch: Batch, needs to contain 'backbone_affine_tensor_rot',
            'backbone_affine_tensor_trans' and 'backbone_affine_mask'. If
            'use_clamped_fape' is present, the clamped and unclamped losses
            are blended with it.
        value: Dictionary containing structure module output, needs to
            contain 'traj', a trajectory of rigids.
        config: Configuration of loss, should contain 'fape.clamp_distance'
            and 'fape.loss_unit_distance'.
    """
    # Predicted trajectory as rigids.
    affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj'])
    rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory)

    # Ground-truth backbone frames as rigids.
    gt_rot = paddle.to_tensor(
        batch['backbone_affine_tensor_rot'], dtype='float32')
    gt_trans = paddle.to_tensor(
        batch['backbone_affine_tensor_trans'], dtype='float32')
    gt_affine = quat_affine.QuatAffine(
        quaternion=None,
        translation=gt_trans,
        rotation=gt_rot)
    gt_rigid = r3.rigids_from_quataffine(gt_affine)
    backbone_mask = paddle.to_tensor(batch['backbone_affine_mask'])

    fape_loss_fn = functools.partial(
        all_atom.frame_aligned_point_error,
        l1_clamp_distance=config.fape.clamp_distance,
        length_scale=config.fape.loss_unit_distance)
    fape_loss = _stacked_trajectory_fape(
        fape_loss_fn, rigid_trajectory, gt_rigid, backbone_mask)

    if 'use_clamped_fape' in batch:
        # Jumper et al. (2021) Suppl. Sec. 1.11.5 "Loss clamping details"
        # NOTE(review): only element [0, 0] is read, which assumes the flag is
        # the same for the whole batch -- confirm against the data pipeline.
        use_clamped_fape = batch['use_clamped_fape'][0, 0]
        unclamped_fape_loss_fn = functools.partial(
            all_atom.frame_aligned_point_error,
            l1_clamp_distance=None,
            length_scale=config.fape.loss_unit_distance)
        fape_loss_unclamped = _stacked_trajectory_fape(
            unclamped_fape_loss_fn, rigid_trajectory, gt_rigid, backbone_mask)
        fape_loss = (fape_loss * use_clamped_fape +
                     fape_loss_unclamped * (1 - use_clamped_fape))

    # 'fape' is the loss of the last trajectory step; the value added to the
    # total loss is the mean over all steps.
    mean_fape = paddle.mean(fape_loss)
    ret['fape'] = fape_loss[-1]
    ret['backbone_fape'] = mean_fape
    ret['loss'] += mean_fape