Пример #1
0
def frames_and_literature_positions_to_atom14_pos(
        aatype: paddle.Tensor,  # (B, N)
        all_frames_to_global: r3.Rigids  # (B, N, 8)
) -> r3.Vecs:  # (B, N, 14)
    """Map every atom's literature position into the global frame (atom14).

    Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11

    Each atom14 slot is assigned one of 8 rigid groups via a per-restype
    lookup table. The matching frame from `all_frames_to_global` is applied
    to that atom's tabulated literature position. The result is then
    multiplied by the per-restype atom14 mask to mask out non-existing atoms.

    Args:
        aatype: aatype for each residue.
        all_frames_to_global: All per residue coordinate frames, one per
            rigid group.
    Returns:
        Positions of all atom coordinates in global frame.
    """
    def _per_residue(table):
        # Turn a per-restype constant table into a tensor and gather the row
        # belonging to each residue in `aatype`.
        return utils.batched_gather(
            paddle.to_tensor(table)[None, ...], aatype, batch_dims=1)

    def _select_group(frame_component, one_hot_groups):
        # (B, N, 8) -> (B, N, 1, 8); the one-hot weights (B, N, 14, 8) keep only
        # the chosen group, and summing over the last axis gives (B, N, 14).
        expanded = paddle.unsqueeze(frame_component, -2)
        return paddle.sum(expanded * one_hot_groups, axis=-1)

    # 8 rigid groups:
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'
    group_idx = _per_residue(residue_constants.restype_atom14_to_rigid_group)
    group_one_hot = nn.functional.one_hot(group_idx, num_classes=8)  # (B, N, 14, 8)

    # One rigid frame per atom14 slot: r3.Rigids with shape (B, N, 14).
    atom_frames = r3.Rigids(
        rot=all_frames_to_global.rot.map(_select_group, group_one_hot),
        trans=all_frames_to_global.trans.map(_select_group, group_one_hot))

    # Literature positions in each atom's local group frame, r3.Vecs (B, N, 14).
    local_positions = r3.vecs_from_tensor(
        _per_residue(residue_constants.restype_atom14_rigid_group_positions))

    # Local -> global, still r3.Vecs with shape (B, N, 14).
    global_positions = r3.rigids_mul_vecs(atom_frames, local_positions)

    # Mask out non-existing atoms.
    atom_exists = _per_residue(residue_constants.restype_atom14_mask)
    return global_positions.map(lambda coord, keep: coord * keep, atom_exists)
Пример #2
0
 def concat_rigids(*arg):
     """Concatenate several Rigids along the last axis of rot.xx and trans.x.

     This is the element axis that sits just before the rotation's two
     trailing matrix axes and the translation's one trailing vector axis.

     Args:
         *arg: Two or more r3.Rigids. The rank of the first two is checked
             to match.
     Returns:
         A new r3.Rigids holding the concatenated rotations and translations.
     """
     assert len(arg) > 1
     reference_rank = len(arg[0].rot.xx.shape)
     assert reference_rank == len(arg[1].rot.xx.shape)

     # axis=-3 for the rotation tensor and axis=-2 for the translation tensor
     # both address the last axis of the scalar components (rot.xx / trans.x).
     joined_rot = paddle.concat([rigid.rot.rotation for rigid in arg], axis=-3)
     joined_trans = paddle.concat(
         [rigid.trans.translation for rigid in arg], axis=-2)
     return r3.Rigids(rot=r3.Rots(joined_rot), trans=r3.Vecs(joined_trans))
Пример #3
0
    def unsqueeze_rigids(rigid, axis=-1):
        """Insert a size-1 axis into a Rigids, positioned in rot.xx / trans.x terms.

        `axis` is interpreted against the shape of the scalar components
        (rot.xx, trans.x). The underlying translation tensor has one extra
        trailing axis and the rotation tensor two, so negative indices are
        shifted by one and two respectively. Non-negative indices are used
        as-is.

        Args:
            rigid: r3.Rigids to expand.
            axis: Position of the new axis in component-shape coordinates.
        Returns:
            A new r3.Rigids with the extra axis.
        """
        trans_shift, rot_shift = (1, 2) if axis < 0 else (0, 0)

        expanded_rot = paddle.unsqueeze(
            rigid.rot.rotation, axis=axis - rot_shift)
        expanded_trans = paddle.unsqueeze(
            rigid.trans.translation, axis=axis - trans_shift)
        return r3.Rigids(rot=r3.Rots(expanded_rot), trans=r3.Vecs(expanded_trans))
Пример #4
0
def sidechain_loss(batch, value, config):
    """All-atom FAPE loss computed against the renamed ground-truth rigids.

    The frame renaming follows Jumper et al. (2021) Suppl. Alg. 26
    "renameSymmetricGroundTruthAtoms" line 7.

    Args:
        batch: Dict providing 'rigidgroups_gt_frames',
            'rigidgroups_alt_gt_frames' and 'rigidgroups_gt_exists'.
        value: Dict providing 'alt_naming_is_better',
            'renamed_atom14_gt_positions', 'renamed_atom14_gt_exists' and a
            'sidechains' dict with 'frames_rot', 'frames_trans', 'atom_pos'.
        config: Needs config.sidechain.atom_clamp_distance and
            config.sidechain.length_scale.
    Returns:
        Dict with 'fape' and 'loss', both set to the same FAPE value.
    """
    # Blend standard and alternative ground-truth frames. A weight of 1
    # selects the alternative naming and 0 keeps the standard one; the weight
    # is broadcast over the last two axes of the frame tensors.
    alt_weight = value['alt_naming_is_better'][:, :, None, None]
    renamed_gt_frames = (
        (1. - alt_weight) * batch['rigidgroups_gt_frames']
        + alt_weight * batch['rigidgroups_alt_gt_frames'])

    n_batch = renamed_gt_frames.shape[0]

    def _flatten(tensor, *trailing):
        # Keep the batch axis and `trailing`, collapse everything in between.
        return paddle.reshape(tensor, [n_batch, -1, *trailing])

    flat_gt_frames = r3.rigids_from_tensor_flat12(
        _flatten(renamed_gt_frames, 12))
    flat_frames_mask = _flatten(batch['rigidgroups_gt_exists'])

    flat_gt_positions = r3.vecs_from_tensor(
        _flatten(value['renamed_atom14_gt_positions'], 3))
    flat_positions_mask = _flatten(value['renamed_atom14_gt_exists'])

    # Score only the final layer (index -1 of the sidechain outputs).
    # NOTE(review): `[-1]` indexes the leading axis, while `_flatten` then
    # treats the new leading axis as batch; confirm the sidechain outputs are
    # laid out layer-first.
    sidechains = value['sidechains']
    flat_pred_frames = r3.Rigids(
        rot=r3.rots_from_tensor3x3(_flatten(sidechains['frames_rot'][-1], 3, 3)),
        trans=r3.vecs_from_tensor(_flatten(sidechains['frames_trans'][-1], 3)))
    flat_pred_positions = r3.vecs_from_tensor(
        _flatten(sidechains['atom_pos'][-1], 3))

    # FAPE Loss on sidechains
    fape = all_atom.frame_aligned_point_error(
        pred_frames=flat_pred_frames,
        target_frames=flat_gt_frames,
        frames_mask=flat_frames_mask,
        pred_positions=flat_pred_positions,
        target_positions=flat_gt_positions,
        positions_mask=flat_positions_mask,
        l1_clamp_distance=config.sidechain.atom_clamp_distance,
        length_scale=config.sidechain.length_scale)

    return {'fape': fape, 'loss': fape}
Пример #5
0
 def slice_rigids(rigid, start, end):
     """Slice a Rigids along the last axis of rot.xx and trans.x.

     Only Rigids whose scalar components (rot.xx) are rank 3 are accepted.

     Args:
         rigid: r3.Rigids to slice.
         start: First index kept (inclusive), as in Python slicing.
         end: Stop index (exclusive), as in Python slicing.
     Returns:
         A new r3.Rigids containing the selected elements.
     """
     assert len(rigid.rot.xx.shape) == 3
     window = slice(start, end)
     # The rotation tensor has two trailing matrix axes and the translation
     # tensor one trailing vector axis, so the element axis sits just before them.
     cut_rot = rigid.rot.rotation[..., window, :, :]
     cut_trans = rigid.trans.translation[..., window, :]
     return r3.Rigids(rot=r3.Rots(cut_rot), trans=r3.Vecs(cut_trans))
Пример #6
0
def backbone_loss(ret, batch, value, config):
    """Backbone FAPE Loss.

    Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17

    Evaluates FAPE once per step of the predicted trajectory against the
    ground-truth backbone frames and stacks the per-step losses. If
    `batch['use_clamped_fape']` is present, the clamped and unclamped losses
    are blended with that flag as the weight.

    Args:
        ret: Dictionary to write outputs into, needs to contain 'loss'.
            Writes 'fape' (loss of the last trajectory step) and
            'backbone_fape' (mean over steps), and adds that mean to 'loss'.
        batch: Batch, needs to contain 'backbone_affine_tensor_rot',
            'backbone_affine_tensor_trans' and 'backbone_affine_mask'.
            'use_clamped_fape' is optional.
        value: Dictionary containing structure module output, needs to contain
            'traj', a trajectory of rigids.
        config: Configuration of loss, should contain 'fape.clamp_distance' and
            'fape.loss_unit_distance'.
    """
    affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj'])
    rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory)

    gt_rot = paddle.to_tensor(batch['backbone_affine_tensor_rot'], dtype='float32')
    gt_trans = paddle.to_tensor(batch['backbone_affine_tensor_trans'], dtype='float32')
    gt_affine = quat_affine.QuatAffine(
        quaternion=None,
        translation=gt_trans,
        rotation=gt_rot)
    gt_rigid = r3.rigids_from_quataffine(gt_affine)
    backbone_mask = paddle.to_tensor(batch['backbone_affine_mask'])

    def _fape_per_step(clamp_distance):
        # Run FAPE once per trajectory step (iterating rot/trans in lockstep)
        # and stack the results. `clamp_distance=None` disables clamping.
        loss_fn = functools.partial(
            all_atom.frame_aligned_point_error,
            l1_clamp_distance=clamp_distance,
            length_scale=config.fape.loss_unit_distance)
        per_step = []
        for step_rot, step_trans in zip(rigid_trajectory.rot,
                                        rigid_trajectory.trans):
            step_rigid = r3.Rigids(step_rot, step_trans)
            per_step.append(loss_fn(step_rigid, gt_rigid, backbone_mask,
                                    step_trans, gt_rigid.trans,
                                    backbone_mask))
        return paddle.stack(per_step)

    fape_loss = _fape_per_step(config.fape.clamp_distance)

    if 'use_clamped_fape' in batch:
        # Jumper et al. (2021) Suppl. Sec. 1.11.5 "Loss clamping details"
        fape_loss_unclamped = _fape_per_step(None)
        # The flag is used as an arithmetic blend weight, so bring it to the
        # loss dtype. A bool/int flag (or a numpy value) would otherwise clash
        # with the float loss in `*` and `1 - ...`.
        use_clamped_fape = paddle.to_tensor(
            batch['use_clamped_fape'][0, 0], dtype=fape_loss.dtype)
        fape_loss = (fape_loss * use_clamped_fape
                     + fape_loss_unclamped * (1 - use_clamped_fape))

    mean_fape = paddle.mean(fape_loss)
    ret['fape'] = fape_loss[-1]
    ret['backbone_fape'] = mean_fape
    ret['loss'] += mean_fape