Example #1
# imports assumed from the se3-transformer-pytorch package this test exercises
import torch

from se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch.utils import fourier_encode


def test_transformer_with_continuous_edges():
    model = SE3Transformer(dim=64,
                           depth=1,
                           attend_self=True,
                           num_degrees=2,
                           output_degrees=2,
                           edge_dim=34)

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask = torch.ones(1, 32).bool()

    pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2))

    edges = fourier_encode(pairwise_continuous_values,
                           num_encodings=8,
                           include_self=True)
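    # with two continuous values per edge and num_encodings = 8, fourier_encode
    # (with include_self = True) yields 2 * (2 * 8 + 1) = 34 features per edge,
    # matching edge_dim = 34 in the model above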

    out = model(feats, coors, mask, edges=edges, return_type=1)
    assert True
    def forward(self,
                seq,
                msa=None,
                mask=None,
                msa_mask=None,
                templates_seq=None,
                templates_dist=None,
                templates_mask=None,
                templates_coors=None,
                templates_sidechains=None,
                embedds=None,
                return_trunk=False,
                return_confidence=False,
                use_eigen_mds=False):
        # unpack (AA_code, atom_pos) before touching seq.shape, since seq may be a tuple

        if isinstance(seq, (list, tuple)):
            seq, seq_pos = seq

        n, device = seq.shape[1], seq.device
        n_range = torch.arange(n, device=device)

        # embed main sequence

        x = self.token_emb(seq)

        # outer sum

        x = rearrange(x, 'b i d -> b () i () d') + rearrange(
            x, 'b j d -> b () () j d')  # create pair-wise residue embeds
        x_mask = rearrange(mask, 'b i -> b () i ()') * rearrange(
            mask, 'b j -> b () () j') if exists(mask) else None
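        # x is now a pairwise representation of shape (b, 1, i, j, d); the singleton
        # dim lines up with the template dimension concatenated further below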

        # embed multiple sequence alignment (msa)

        m = None
        msa_shape = None
        if exists(msa):
            m = self.token_emb(msa)

            msa_shape = m.shape
            m = rearrange(m, 'b m n d -> b (m n) d')

            # default msa_mask to all ones if none was passed
            msa_mask = default(msa_mask, torch.ones_like(msa).bool())

        elif exists(embedds):
            m = self.embedd_project(embedds)

            msa_shape = m.shape
            m = rearrange(m, 'b m n d -> b (m n) d')

            # default msa_mask to all ones if none was passed
            msa_mask = default(msa_mask,
                               torch.ones_like(embedds[..., -1]).bool())

        if exists(msa_mask):
            msa_mask = rearrange(msa_mask, 'b m n -> b (m n)')
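
        # msa features and mask are flattened over (num_alignments, seq_len) so the
        # trunk attends across all alignment tokens; msa_shape records the original dims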

        # embed templates, if present

        if exists(templates_seq):
            assert exists(
                templates_coors
            ), 'template residue coordinates must be supplied as `templates_coors`'
            _, num_templates, *_ = templates_seq.shape

            if not exists(templates_dist):
                templates_dist = get_bucketed_distance_matrix(
                    templates_coors, templates_mask,
                    constants.DISTOGRAM_BUCKETS)

            # embed template

            t_seq = self.token_emb(templates_seq)

            # if sidechain information is present
            # color the residue embeddings with the sidechain type 1 features
            # todo (make efficient)

            if exists(templates_sidechains):
                if self.use_se3_transformer:
                    t_seq = self.template_sidechain_emb(t_seq,
                                                        templates_sidechains,
                                                        templates_coors,
                                                        mask=templates_mask)
                else:
                    shape = t_seq.shape
                    t_seq = rearrange(t_seq, 'b t n d -> (b t) n d')
                    templates_coors = rearrange(templates_coors,
                                                'b t n c -> (b t) n c')
                    en_mask = rearrange(templates_mask, 'b t n -> (b t) n')

                    t_seq, _ = self.template_sidechain_emb(t_seq,
                                                           templates_coors,
                                                           mask=en_mask)

                    t_seq = t_seq.reshape(*shape)

            # embed template distances

            t_dist = self.template_dist_emb(templates_dist)

            t_seq = rearrange(t_seq, 'b t i d -> b t i () d') + rearrange(
                t_seq, 'b t j d -> b t () j d')
            t = t_seq + t_dist

            # template pos emb

            template_num_pos_emb = self.template_num_pos_emb(
                torch.arange(num_templates, device=device))
            t += rearrange(template_num_pos_emb, 't d -> () t () () d')

            assert t.shape[-2:] == x.shape[-2:]

            x = torch.cat((x, t), dim=1)

            if exists(templates_mask):
                t_mask = rearrange(templates_mask,
                                   'b t i -> b t i ()') * rearrange(
                                       templates_mask, 'b t j -> b t () j')
                x_mask = torch.cat((x_mask, t_mask), dim=1)

        # flatten

        seq_shape = x.shape
        x = rearrange(x, 'b t i j d -> b (t i j) d')
        x_mask = rearrange(x_mask,
                           'b t i j -> b (t i j)') if exists(x_mask) else None

        # trunk

        x, m = self.net(x,
                        m,
                        seq_shape,
                        msa_shape,
                        mask=x_mask,
                        msa_mask=msa_mask)

        # remove templates, if present

        x = x.view(seq_shape)
        x = x[:, 0]

        # calculate theta and phi before symmetrization

        if self.predict_angles:
            theta_logits = self.to_prob_theta(x)
            phi_logits = self.to_prob_phi(x)

        # embeds to distogram

        trunk_embeds = (x +
                        rearrange(x, 'b i j d -> b j i d')) * 0.5  # symmetrize
        distance_pred = self.to_distogram_logits(trunk_embeds)

        # determine angles, if specified

        ret = distance_pred

        if self.predict_angles:
            omega_input = trunk_embeds if self.symmetrize_omega else x
            omega_logits = self.to_prob_omega(omega_input)
            ret = Logits(distance_pred, theta_logits, phi_logits, omega_logits)

        if not self.predict_coords or return_trunk:
            return ret

        # prepare mask for backbone coordinates

        assert self.num_backbone_atoms > 1, 'the backbone must consist of at least 3 atomic coordinates'

        N_mask, CA_mask, C_mask = scn_backbone_mask(
            seq, boolean=True, n_aa=self.num_backbone_atoms)

        cloud_mask = scn_cloud_mask(seq, boolean=True)
        flat_cloud_mask = rearrange(cloud_mask, 'b l c -> b (l c)')
        chain_mask = (mask.unsqueeze(-1) * cloud_mask)
        flat_chain_mask = rearrange(chain_mask, 'b l c -> b (l c)')

        bb_flat_mask = rearrange(chain_mask[..., :self.num_backbone_atoms],
                                 'b l c -> b (l c)')
        bb_flat_mask_crossed = rearrange(bb_flat_mask,
                                         'b i -> b i ()') * rearrange(
                                             bb_flat_mask, 'b j -> b () j')
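        # pairwise backbone mask: entry (i, j) is True only when both positions hold
        # real backbone atoms; used below to zero the distogram weights over padding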

        # structural refinement

        if self.predict_real_value_distances:
            distances, distance_std = distance_pred.unbind(dim=-1)
            weights = (1 / (1 + distance_std)
                       )  # could also do a distance_std.sigmoid() here
        else:
            distances, weights = center_distogram_torch(distance_pred)

        # set unwanted atoms to weight=0 (like C-beta in glycine)
        weights.masked_fill_(torch.logical_not(bb_flat_mask_crossed), 0.)

        coords_3d, _ = MDScaling(
            distances,
            weights=weights if not use_eigen_mds else None,
            iters=self.mds_iters,
            fix_mirror=True,
            N_mask=N_mask,
            CA_mask=CA_mask,
            C_mask=C_mask)
        coords = rearrange(coords_3d, 'b c n -> b n c')
        # will init all sidechain coords to cbeta if present else c_alpha
        coords = sidechain_container(coords,
                                     n_aa=self.num_backbone_atoms,
                                     cloud_mask=cloud_mask)
        coords = rearrange(coords, 'b n l d -> b (n l) d')
        atom_tokens = scn_atom_embedd(seq)  # not used for now, but could be

        num_atoms = cloud_mask.shape[-1]

        structure_embed = self.trunk_to_structure_dim(trunk_embeds)
        x = reduce(structure_embed, 'b i j d -> b i d', 'mean')
        x += self.structure_module_embeds(seq)
        x = repeat(x, 'b n d -> b n l d', l=num_atoms)
        x += self.atom_tokens_embed(atom_tokens)
        x = rearrange(x, 'b n l d -> b (n l) d')
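        # x now carries one embedding per atom, flattened to (b, (n l), d) so it is
        # aligned with the flattened atomic coords above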

        # derive edges from trunk -> equivariant network, if needed

        edges = None
        if exists(self.to_equivariant_net_edges):
            edges = self.to_equivariant_net_edges(trunk_embeds)
            edges = fourier_encode(
                edges,
                num_encodings=self.se3_edges_fourier_encodings,
                include_self=True)
            edges = repeat(edges,
                           'b i j d -> b (i l) (j m) d',
                           l=num_atoms,
                           m=num_atoms)
            edges = edges.double()

        # prepare float64 precision for equivariance

        original_dtype = coords.dtype
        x, coords = map(lambda t: t.double(), (x, coords))

        # derive adjacency matrix
        # todo - fix so Cbeta is connected correctly

        adj_idxs, adj_num = prot_covalent_bond(seq,
                                               adj_degree=1,
                                               cloud_mask=cloud_mask)
        adj_mat = adj_num.bool()

        # /adjacency mat calc - above should be pre-calculated and cached in a buffer

        with torch_default_dtype(torch.float64):
            for _ in range(self.structure_module_refinement_iters):
                x, coords = self.structure_module(x,
                                                  coords,
                                                  mask=flat_chain_mask,
                                                  adj_mat=adj_mat,
                                                  edges=edges)

        coords = coords.type(original_dtype)

        if self.return_aux_logits:
            return coords, ret

        if return_confidence:
            return coords, self.lddt_linear(x.float())

        return coords
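
A minimal sketch of how this forward pass might be invoked, assuming the surrounding class is the Alphafold2 module from the same codebase; the constructor arguments, sequence length, and MSA depth below are illustrative assumptions rather than values taken from this file.

model = Alphafold2(dim=256, depth=2, heads=8, dim_head=64)  # hypothetical configuration

seq = torch.randint(0, 21, (1, 128))       # residue tokens
msa = torch.randint(0, 21, (1, 5, 128))    # multiple sequence alignment tokens
mask = torch.ones(1, 128).bool()
msa_mask = torch.ones(1, 5, 128).bool()

distogram = model(seq, msa=msa, mask=mask, msa_mask=msa_mask)
# returns distogram logits (plus angle logits when predict_angles is set);
# coordinates are returned instead when predict_coords is enabled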
    def forward(
        self,
        inp,
        edge_info,
        rel_dist = None,
        basis = None
    ):
        splits = self.splits
        neighbor_indices, neighbor_masks, edges = edge_info
        rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

        kernels = {}
        outputs = {}

        if self.fourier_encode_dist:
            rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)

        # split basis

        basis_keys = basis.keys()
        split_basis_values = list(zip(*list(map(lambda t: fast_split(t, splits, dim = 1), basis.values()))))
        split_basis = list(map(lambda v: dict(zip(basis_keys, v)), split_basis_values))
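        # each element of split_basis is a dict with the same keys as basis,
        # holding one chunk of every basis tensor along dim 1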

        # go through every permutation of input degree type to output degree type

        for degree_out in self.fiber_out.degrees:
            output = 0
            degree_out_key = str(degree_out)

            for degree_in, m_in in self.fiber_in:
                etype = f'({degree_in},{degree_out})'

                x = inp[str(degree_in)]

                x = batched_index_select(x, neighbor_indices, dim = 1)
                x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

                kernel_fn = self.kernel_unary[etype]
                edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist

                output_chunk = None
                split_x = fast_split(x, splits, dim = 1)
                split_edge_features = fast_split(edge_features, splits, dim = 1)

                # process input, edges, and basis in chunks along the sequence dimension

                for x_chunk, edge_features_chunk, basis_chunk in zip(split_x, split_edge_features, split_basis):
                    kernel = kernel_fn(edge_features_chunk, basis = basis_chunk)
                    chunk = einsum('... o i, ... i c -> ... o c', kernel, x_chunk)
                    output_chunk = safe_cat(output_chunk, chunk, dim = 1)

                output = output + output_chunk

            if self.pool:
                output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2)

            leading_shape = x.shape[:2] if self.pool else x.shape[:3]
            output = output.view(*leading_shape, -1, to_order(degree_out))

            outputs[degree_out_key] = output

        if self.self_interaction:
            self_interact_out = self.self_interact(inp)
            outputs = self.self_interact_sum(outputs, self_interact_out)

        return outputs
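
A minimal sketch of the input layout this forward expects, assuming the SE(3)-transformer conventions used above: inp is a dict keyed by degree string, each tensor shaped (batch, seq, multiplicity, 2 * degree + 1), and edge_info is a tuple of neighbor indices, neighbor masks, and optional per-edge features. The shapes, multiplicities, and neighbor count below are illustrative only; the basis dict would come from the library's precomputed equivariant bases and is omitted here.

inp = {
    '0': torch.randn(1, 32, 16, 1),  # degree-0 (scalar) features, multiplicity 16
    '1': torch.randn(1, 32, 16, 3),  # degree-1 (vector) features, multiplicity 16
}
neighbor_indices = torch.randint(0, 32, (1, 32, 8))  # 8 neighbors per node
neighbor_masks = torch.ones(1, 32, 8).bool()
edge_info = (neighbor_indices, neighbor_masks, None)  # no extra edge features
rel_dist = torch.randn(1, 32, 8).abs()                # distances to each neighbor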