def test_coords_En():
    """Predict coordinates and scatter them into a sidechainnet-shaped tensor.

    The flat coordinate prediction is written into a zero tensor shaped like
    the cloud mask plus a trailing xyz axis. The points selected by the chain
    mask must then match the flat prediction selected by the flat chain mask.
    """
    model = Alphafold2(
        dim=256,
        depth=2,
        heads=2,
        dim_head=32,
        use_se3_transformer=False,
        predict_coords=True,
        num_backbone_atoms=3,
    )

    residues = torch.randint(0, 21, (2, 16))
    residue_mask = torch.ones_like(residues, dtype=torch.bool)

    alignment = torch.randint(0, 21, (2, 5, 32))
    alignment_mask = torch.ones_like(alignment, dtype=torch.bool)

    coords = model(residues, alignment, mask=residue_mask, msa_mask=alignment_mask)

    # cloud: every atom slot of the protein; chain: the slots we have labels for
    cloud_mask = scn_cloud_mask(residues, boolean=True)
    flat_cloud_mask = rearrange(cloud_mask, 'b l c -> b (l c)')
    chain_mask = residue_mask.unsqueeze(-1) * cloud_mask
    flat_chain_mask = rearrange(chain_mask, 'b l c -> b (l c)')

    # zero tensor with cloud_mask's shape plus xyz, on coords' dtype and device
    scn_coords = coords.new_zeros((*cloud_mask.shape, 3))
    scn_coords[cloud_mask] = coords[flat_cloud_mask]

    expected_shape = coords[flat_chain_mask].shape
    assert scn_coords[chain_mask].shape == expected_shape, 'must output coordinates'
def test_templates():
    """Smoke-test a forward pass that consumes template features and angles."""
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        templates_dim=32,
        templates_angles_feats_dim=32,
    )

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 16))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    # template tensors are built in the same order as before (feats, angles, mask)
    template_inputs = dict(
        templates_feats=torch.randn(2, 3, 16, 16, 32),
        templates_angles=torch.randn(2, 3, 16, 32),
        templates_mask=torch.ones(2, 3, 16, dtype=torch.bool),
    )

    model(seq, msa, mask=mask, msa_mask=msa_mask, **template_inputs)
    assert True
# Example #3
# 0
def test_templates_en():
    """Smoke-test the 'en' template embedder with template sequences and coordinates."""
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        template_embedder_type='en',
        attn_types=('full', 'intra_attn', 'seq_only'),
    )

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 32))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    template_seq = torch.randint(0, 21, (2, 2, 16))
    template_inputs = dict(
        templates_seq=template_seq,
        templates_coors=torch.randn(2, 2, 16, 3),
        templates_mask=torch.ones_like(template_seq, dtype=torch.bool),
    )

    model(seq, msa, mask=mask, msa_mask=msa_mask, **template_inputs)
    assert True
def test_embeddings():
    """Feed precomputed embeddings in place of an MSA, first without a mask and then with one."""
    model = Alphafold2(dim=256, depth=2, heads=8, dim_head=64)

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq, dtype=torch.bool)

    embedds = torch.randn(2, 1, 16, 1280)
    # one flag per embedding position (the feature axis is dropped)
    embedds_mask = torch.ones_like(embedds[..., -1], dtype=torch.bool)

    # first pass: no mask; second pass: all-true mask
    for embedds_msa_mask in (None, embedds_mask):
        model(seq, mask=mask, embedds=embedds, msa_mask=embedds_msa_mask)

    assert True
def test_confidence_En():
    """Confidence output must agree with coordinates on every axis but the last."""
    model = Alphafold2(
        dim=256,
        depth=1,
        heads=8,
        dim_head=64,
        use_se3_transformer=False,
        predict_coords=True,
        num_backbone_atoms=3,
    )

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 32))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    outputs = model(seq, msa, mask=mask, msa_mask=msa_mask, return_confidence=True)
    coords, confidences = outputs

    leading_coord_dims = coords.shape[:-1]
    assert leading_coord_dims == confidences.shape[:-1]
def test_coords_En_backwards():
    """Back-propagate through the coordinate head built without the SE(3) transformer.

    A scalar loss (the sum of the predicted coordinates) is back-propagated.
    The original test only asserted ``True``, which proved nothing beyond
    "did not raise". The test now also requires that at least one model
    parameter received a gradient.
    """
    model = Alphafold2(
        dim = 256,
        depth = 2,
        heads = 8,
        dim_head = 64,
        use_se3_transformer = False,
        predict_coords = True,
        num_backbone_atoms = 3
    )

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq).bool()

    msa = torch.randint(0, 21, (2, 5, 32))
    msa_mask = torch.ones_like(msa).bool()

    coords = model(
        seq,
        msa,
        mask = mask,
        msa_mask = msa_mask
    )

    coords.sum().backward()

    # a successful backward pass must leave gradients on the weights it reached
    has_grad = any(param.grad is not None for param in model.parameters())
    assert has_grad, 'must be able to go backwards through MDS and center distogram'
def test_real_value_distance_with_coords():
    """Real-valued distance prediction combined with coordinate output.

    The test expects 14 coordinate slots per residue (16 residues -> 16 * 14 points).
    """
    config = dict(
        dim=256,
        depth=2,
        heads=8,
        dim_head=64,
        predict_coords=True,
        predict_real_value_distances=True,
        num_backbone_atoms=3,
        structure_module_dim=1,
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
    )
    model = Alphafold2(**config)

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 32))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    coords = model(seq, msa, mask=mask, msa_mask=msa_mask)

    expected = (2, 16 * 14, 3)
    assert coords.shape == expected, 'must output coordinates'
def test_no_msa():
    """The model must run on a bare sequence with no MSA supplied."""
    model = Alphafold2(dim=32, depth=2, heads=2, dim_head=32)

    residues = torch.randint(0, 21, (2, 128))
    model(residues, mask=torch.ones_like(residues, dtype=torch.bool))

    assert True
# Example #9
# 0
def test_main():
    """Basic distogram forward pass with a sequence and an MSA."""
    model = Alphafold2(dim=256, depth=2, heads=8, dim_head=64)

    seq, msa = torch.randint(0, 21, (2, 128)), torch.randint(0, 21, (2, 5, 64))
    mask, msa_mask = (torch.ones_like(t, dtype=torch.bool) for t in (seq, msa))

    model(seq, msa, mask=mask, msa_mask=msa_mask)
    assert True
def test_reversible():
    """Back-propagate through a model built with ``reversible=True``.

    The summed distogram is back-propagated. The original test ended with
    ``assert True``, which checks nothing. The test now also requires that
    the backward pass populated at least one parameter gradient.
    """
    model = Alphafold2(dim=32, depth=2, heads=2, dim_head=32, reversible=True)

    seq = torch.randint(0, 21, (2, 128))
    msa = torch.randint(0, 21, (2, 5, 64))
    mask = torch.ones_like(seq).bool()
    msa_mask = torch.ones_like(msa).bool()

    distogram = model(seq, msa, mask=mask, msa_mask=msa_mask)

    distogram.sum().backward()

    # require evidence that gradients actually reached the weights
    has_grad = any(param.grad is not None for param in model.parameters())
    assert has_grad, 'backward pass must produce parameter gradients'
# Example #11
# 0
def test_custom_blocks():
    """Build the trunk from an explicit list of block types instead of ``depth``."""
    block_types = ('conv', 'conv', 'self', 'self', 'cross')
    model = Alphafold2(
        dim=32,
        heads=2,
        dim_head=32,
        custom_block_types=block_types,
    )

    seq = torch.randint(0, 21, (2, 128))
    msa = torch.randint(0, 21, (2, 5, 64))
    mask = torch.ones_like(seq, dtype=torch.bool)
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    model(seq, msa, mask=mask, msa_mask=msa_mask)
    assert True
def test_msa_tie_row_attn():
    """Smoke-test the ``msa_tie_row_attn=True`` option."""
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        msa_tie_row_attn=True,
    )

    seq = torch.randint(0, 21, (2, 128))
    msa = torch.randint(0, 21, (2, 5, 64))

    model(
        seq,
        msa,
        mask=torch.ones_like(seq, dtype=torch.bool),
        msa_mask=torch.ones_like(msa, dtype=torch.bool),
    )
    assert True
def test_anglegrams():
    """Smoke-test angle prediction with an MSA as long as the sequence."""
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        predict_angles=True,
    )

    seq = torch.randint(0, 21, (2, 128))
    msa = torch.randint(0, 21, (2, 5, 128))
    mask = torch.ones_like(seq, dtype=torch.bool)
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    model(seq, msa, mask=mask, msa_mask=msa_mask)
    assert True
# Example #14
# 0
def test_use_conv():
    """Smoke-test the convolutional trunk with a custom dilation schedule."""
    model = Alphafold2(
        dim=32,
        depth=4,
        heads=2,
        dim_head=32,
        use_conv=True,
        dilations=(1, 3, 5),
    )

    seq = torch.randint(0, 21, (2, 128))
    msa = torch.randint(0, 21, (2, 5, 64))
    mask = torch.ones_like(seq, dtype=torch.bool)
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    model(seq, msa, mask=mask, msa_mask=msa_mask)
    assert True
# Example #15
# 0
def test_anglegrams():
    """With ``predict_angles=True`` the model must return exactly four outputs.

    The four outputs are unpacked as distogram, theta, phi and omega.
    """
    model = Alphafold2(
        dim=256,
        depth=2,
        heads=8,
        dim_head=64,
        predict_angles=True,
    )

    seq = torch.randint(0, 21, (2, 128))
    msa = torch.randint(0, 21, (2, 5, 64))
    mask = torch.ones_like(seq, dtype=torch.bool)
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    outputs = model(seq, msa, mask=mask, msa_mask=msa_mask)
    # unpacking enforces the four-output contract
    distogram, theta, phi, omega = outputs
    assert True
# Example #16
# 0
def test_templates():
    """Pass integer template tensors (values drawn from [0, 37)) with an all-true mask."""
    model = Alphafold2(dim=256, depth=2, heads=8, dim_head=64)

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 32))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    templates = torch.randint(0, 37, (2, 2, 16, 16))
    templates_mask = torch.ones_like(templates, dtype=torch.bool)

    model(
        seq,
        msa,
        mask=mask,
        msa_mask=msa_mask,
        templates=templates,
        templates_mask=templates_mask,
    )
# Example #17
# 0
def test_coords_egnn_backwards():
    """Back-propagate through the 'egnn' structure module with coordinate refinement.

    The original test ended with ``assert True``, which checks nothing after
    ``backward()``. The test now requires that at least one parameter
    received a gradient.
    """
    model = Alphafold2(dim=256,
                       depth=2,
                       heads=2,
                       dim_head=32,
                       structure_module_type="egnn",
                       predict_coords=True,
                       refine_coords=True)

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq).bool()

    msa = torch.randint(0, 21, (2, 5, 32))
    msa_mask = torch.ones_like(msa).bool()

    coords = model(seq, msa, mask=mask, msa_mask=msa_mask)

    coords.sum().backward()

    # a successful backward pass must leave gradients on the weights it reached
    has_grad = any(param.grad is not None for param in model.parameters())
    assert has_grad, 'must be able to go backwards through MDS and center distogram'
def test_confidence():
    """Returned confidences must share every leading dimension with the coordinates."""
    model = Alphafold2(
        dim=256,
        depth=1,
        heads=2,
        dim_head=32,
        predict_coords=True,
    )

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 16))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    coords, confidences = model(
        seq, msa, mask=mask, msa_mask=msa_mask, return_confidence=True
    )

    assert confidences.shape[:-1] == coords.shape[:-1]
def test_edges_to_equivariant_network():
    """Predict coordinates and angles together and request confidences; must not raise."""
    model = Alphafold2(
        dim=32,
        depth=1,
        heads=2,
        dim_head=32,
        predict_coords=True,
        predict_angles=True,
    )

    seq = torch.randint(0, 21, (2, 32))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 32))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    outputs = model(seq, msa, mask=mask, msa_mask=msa_mask, return_confidence=True)
    coords, confidences = outputs
    assert True, 'should run without errors'
def test_mds():
    """Coordinate prediction with a tiny structure module yields one point per residue."""
    structure_kwargs = dict(
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
    )
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
        **structure_kwargs,
    )

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 16))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    coords = model(seq, msa, mask=mask, msa_mask=msa_mask)

    batch, length = seq.shape
    assert coords.shape == (batch, length, 3), 'must output coordinates'
# Example #21
# 0
def test_edges_to_equivariant_network():
    """Coordinates and angles without the SE(3) transformer, with confidences; must not raise."""
    model = Alphafold2(
        dim=256,
        depth=1,
        heads=8,
        dim_head=64,
        use_se3_transformer=False,
        predict_coords=True,
        predict_angles=True,
        num_backbone_atoms=3,
    )

    seq = torch.randint(0, 21, (2, 16))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 32))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    outputs = model(seq, msa, mask=mask, msa_mask=msa_mask, return_confidence=True)
    coords, confidences = outputs
    assert True, 'should run without errors'
def test_recycling():
    """Run two passes, feeding the recyclables from the first into the second."""
    model = Alphafold2(
        dim=128,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
    )

    seq = torch.randint(0, 21, (2, 4))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 4))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    extra_msa = torch.randint(0, 21, (2, 5, 4))
    extra_msa_mask = torch.ones_like(extra_msa, dtype=torch.bool)

    # keyword arguments common to both passes
    shared = dict(
        mask=mask,
        msa_mask=msa_mask,
        extra_msa=extra_msa,
        extra_msa_mask=extra_msa_mask,
        return_aux_logits=True,
        return_recyclables=True,
    )

    coords, ret = model(seq, msa, **shared)
    coords, ret = model(seq, msa, recyclables=ret.recyclables, **shared)

    assert True
# Example #23
# 0
def test_custom_coords_module():
    """Plug a user-supplied coordinate module into the model.

    The stub projects summed trunk embeddings to xyz per residue. It then
    repeats each residue's point once per atom slot in ``cloud_mask``. With
    8 residues the output must therefore have ``8 * 3`` points.
    """
    class CustomCoords(nn.Module):
        def __init__(self, dim, structure_module_dim):
            super().__init__()
            # structure_module_dim is accepted for interface parity only
            self.to_coords = nn.Linear(dim, 3)

        def forward(self, *, distance_pred, trunk_embeds, cloud_mask,
                    **kwargs):
            atoms_per_residue = cloud_mask.shape[-1]
            # assumes trunk_embeds is (b, n, n, dim) -- summing axis 2 gives (b, n, dim)
            per_residue = self.to_coords(trunk_embeds.sum(dim=2))
            return repeat(per_residue, 'b n c -> b (n l) c', l=atoms_per_residue)

    coords_module = CustomCoords(dim=32, structure_module_dim=4)

    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
        refine_coords=True,
        structure_module_dim=4,
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
        structure_module_knn=2,
        coords_module=coords_module,
    )

    seq = torch.randint(0, 21, (2, 8))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 16))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    coords = model(seq, msa, mask=mask, msa_mask=msa_mask)

    expected = (2, 8 * 3, 3)
    assert coords.shape == expected, 'must output coordinates'
def test_extra_msa():
    """Smoke-test a forward pass with an additional ``extra_msa`` input."""
    model = Alphafold2(
        dim=128,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
    )

    seq = torch.randint(0, 21, (2, 4))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 4))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    extra_msa = torch.randint(0, 21, (2, 5, 4))
    extra_msa_mask = torch.ones_like(extra_msa, dtype=torch.bool)

    model(
        seq,
        msa,
        mask=mask,
        msa_mask=msa_mask,
        extra_msa=extra_msa,
        extra_msa_mask=extra_msa_mask,
    )
    assert True
# Example #25
# 0
def test_coords_backbone_with_cbeta():
    """With ``atoms='backbone-with-cbeta'`` the model must emit four points per residue."""
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        atoms='backbone-with-cbeta',
        predict_coords=True,
        refine_coords=True,
        structure_module_dim=1,
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
        structure_module_knn=2,
    )

    seq = torch.randint(0, 21, (2, 8))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 16))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    coords = model(seq, msa, mask=mask, msa_mask=msa_mask)

    points_per_residue = 4
    assert coords.shape == (2, 8 * points_per_residue, 3), 'must output coordinates'
# Example #26
# 0
def test_real_value_distance_with_coords():
    """Real-valued distances with refined coordinates; expects three points per residue."""
    model = Alphafold2(
        dim=32,
        depth=1,
        heads=2,
        dim_head=16,
        predict_coords=True,
        refine_coords=True,
        predict_real_value_distances=True,
        structure_module_dim=1,
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
        structure_module_knn=2,
    )

    seq = torch.randint(0, 21, (2, 8))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 16))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    coords = model(seq, msa, mask=mask, msa_mask=msa_mask)

    points_per_residue = 3
    assert coords.shape == (2, 8 * points_per_residue, 3), 'must output coordinates'
# Example #27
# 0
def test_coords_En():
    """Run the 'en' structure module with refinement, then build the cloud and chain masks.

    Only the absence of errors is checked. The masks are computed so that
    ``scn_cloud_mask`` and the ``rearrange`` patterns are exercised on this
    input.
    """
    model = Alphafold2(
        dim=256,
        depth=2,
        heads=2,
        dim_head=32,
        structure_module_type="en",
        predict_coords=True,
        refine_coords=True,
    )

    residues = torch.randint(0, 21, (2, 16))
    residue_mask = torch.ones_like(residues, dtype=torch.bool)

    alignment = torch.randint(0, 21, (2, 5, 32))
    alignment_mask = torch.ones_like(alignment, dtype=torch.bool)

    coords = model(residues, alignment, mask=residue_mask, msa_mask=alignment_mask)

    # cloud: every atom slot of the protein; chain: the slots we have labels for
    cloud_mask = scn_cloud_mask(residues, boolean=True)
    flat_cloud_mask = rearrange(cloud_mask, 'b l c -> b (l c)')
    chain_mask = residue_mask.unsqueeze(-1) * cloud_mask
    flat_chain_mask = rearrange(chain_mask, 'b l c -> b (l c)')

    assert True
def test_coords_se3_backwards():
    """Back-propagate through coordinate prediction built with a tiny structure module.

    The original test ended with ``assert True``, which checks nothing after
    ``backward()``. The test now requires that at least one parameter
    received a gradient.
    """
    model = Alphafold2(dim=256,
                       depth=2,
                       heads=2,
                       dim_head=32,
                       predict_coords=True,
                       num_backbone_atoms=3,
                       structure_module_dim=1,
                       structure_module_depth=1,
                       structure_module_heads=1,
                       structure_module_dim_head=1,
                       structure_module_knn=1)

    seq = torch.randint(0, 21, (2, 8))
    mask = torch.ones_like(seq).bool()

    msa = torch.randint(0, 21, (2, 5, 16))
    msa_mask = torch.ones_like(msa).bool()

    coords = model(seq, msa, mask=mask, msa_mask=msa_mask)

    coords.sum().backward()

    # a successful backward pass must leave gradients on the weights it reached
    has_grad = any(param.grad is not None for param in model.parameters())
    assert has_grad, 'must be able to go backwards through MDS and center distogram'
def test_coords_se3_with_global_nodes():
    """Structure module with two global nodes; expects 14 points per residue."""
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
        num_backbone_atoms=3,
        structure_module_dim=1,
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
        structure_module_knn=2,
        structure_num_global_nodes=2,
    )

    seq = torch.randint(0, 21, (2, 8))
    mask = torch.ones_like(seq, dtype=torch.bool)

    msa = torch.randint(0, 21, (2, 5, 16))
    msa_mask = torch.ones_like(msa, dtype=torch.bool)

    coords = model(seq, msa, mask=mask, msa_mask=msa_mask)

    points_per_residue = 14
    assert coords.shape == (2, 8 * points_per_residue, 3), 'must output coordinates'
# Example #30
# 0
def test_inter_msa_self_attn():
    """Back-propagate through a reversible model with ``inter_msa_self_attn=False``.

    The original test ended with ``assert True``, which checks nothing after
    ``backward()``. The test now requires that at least one parameter
    received a gradient.
    """
    model = Alphafold2(
        dim = 256,
        depth = 2,
        heads = 8,
        dim_head = 64,
        reversible = True,
        inter_msa_self_attn = False  # turn off attention across MSA sequences
    )

    seq = torch.randint(0, 21, (2, 128))
    msa = torch.randint(0, 21, (2, 5, 64))
    mask = torch.ones_like(seq).bool()
    msa_mask = torch.ones_like(msa).bool()

    distogram = model(
        seq,
        msa,
        mask = mask,
        msa_mask = msa_mask
    )

    distogram.sum().backward()

    # require evidence that gradients actually reached the weights
    has_grad = any(param.grad is not None for param in model.parameters())
    assert has_grad, 'backward pass must produce parameter gradients'