Пример #1
0
 def forward(self, decoys: DecoyBatch) -> DecoyBatch:
     """Embed ``decoys.residues`` and prepend the result to the node features.

     The parent class's ``forward`` (presumably an embedding layer, given the
     class name ``ResidueEmbedding`` -- TODO confirm) is applied to
     ``decoys.residues``. Its output is concatenated in front of the existing
     ``node_features`` along dim 1. The returned batch has ``residues`` set to
     None, so the raw residues are not carried forward.
     """
     parent = super(ResidueEmbedding, self)
     embedded = parent.forward(decoys.residues)
     # Embedding columns first, then the previously existing node features.
     combined = torch.cat([embedded, decoys.node_features], dim=1)
     return decoys.evolve(residues=None, node_features=combined)
Пример #2
0
 def __call__(self, decoys: DecoyBatch):
     """Make the edge set bidirectional by appending a reversed copy of every edge.

     For each edge (s -> r) an edge (r -> s) is added with the same features.
     Reversed edges are appended after ALL original edges (concatenation along
     dim 0), so the edges of a single graph are not contiguous afterwards.
     ``edge_index_by_graph`` is tiled in the same order so it stays aligned
     with ``senders``/``receivers``, and per-graph edge counts are doubled.

     NOTE(review): ``edge_features.repeat(2, 1)`` tiles rows only if
     ``edge_features`` is 2-D (num_edges, F) -- assumed, TODO confirm.
     """
     fwd_send, fwd_recv = decoys.senders, decoys.receivers
     all_send = torch.cat((fwd_send, fwd_recv), dim=0)
     all_recv = torch.cat((fwd_recv, fwd_send), dim=0)
     return decoys.evolve(
         senders=all_send,
         receivers=all_recv,
         edge_features=decoys.edge_features.repeat(2, 1),
         num_edges_by_graph=2 * decoys.num_edges_by_graph,
         edge_index_by_graph=decoys.edge_index_by_graph.repeat(2),
     )
Пример #3
0
    def forward(self, decoys: DecoyBatch):
        """Fold raw distances into the edge features as one scalar column.

        Each entry of ``decoys.distances`` becomes a single column placed before
        the existing ``edge_features`` columns (concatenation along dim 1).
        ``distances`` is set to None on the returned batch.
        """
        distance_column = decoys.distances.unsqueeze(1)
        merged = torch.cat([distance_column, decoys.edge_features], dim=1)
        return decoys.evolve(distances=None, edge_features=merged)
Пример #4
0
    def forward(self, decoys: DecoyBatch):
        """Expand distances onto a bank of Gaussian RBFs and prepend them to edge features.

        For every distance d and every center c in ``self.rbf_centers``, the
        feature is ``exp(-(d - c)**2)``. This is an unnormalized Gaussian with
        variance 1/2 (``exp(-x**2 / (2 * sigma**2))`` with ``sigma**2 = 1/2``),
        not unit variance. The centers are presumably equally spaced, but they
        are configured elsewhere (TODO confirm).

        The RBF columns go before the existing ``edge_features`` columns, and
        ``distances`` is set to None on the returned batch.
        """
        # Pairwise (distance, center) offsets via broadcasting.
        offsets = decoys.distances.unsqueeze(1) - self.rbf_centers.unsqueeze(0)
        rbf_features = torch.exp(-offsets.pow(2))

        merged = torch.cat([rbf_features, decoys.edge_features], dim=1)
        return decoys.evolve(distances=None, edge_features=merged)
Пример #5
0
    def forward(self, decoys: DecoyBatch):
        """Run the full model pipeline and strip edge-level data from the result.

        The stages ``preprocessing``, ``encoder``, ``layers`` and ``readout`` are
        applied in that order. The returned batch has every edge-related field
        set to None: ``senders``, ``receivers``, ``edge_features``,
        ``num_edges_by_graph`` and ``edge_index_by_graph``. Node and global
        features are passed through explicitly with their current values.
        """
        pipeline = (self.preprocessing, self.encoder, self.layers,
                    self.readout)
        for stage in pipeline:
            decoys = stage(decoys)

        return decoys.evolve(
            senders=None,
            receivers=None,
            edge_features=None,
            num_edges_by_graph=None,
            edge_index_by_graph=None,
            node_features=decoys.node_features,
            global_features=decoys.global_features,
        )
Пример #6
0
    def forward(self, decoys: DecoyBatch):
        """Append a one-hot encoding of binned edge separation to the edge features.

        Per edge, the value bucketed here is ``senders - receivers + 1``, i.e. the
        negation of ``receivers - senders - 1``. It is located in ``self.bins``
        with ``searchsorted``, and the resulting index is flipped as
        ``(num_bins - 1) - idx``. A float one-hot matrix of shape
        (num_edges, num_bins) is then concatenated AFTER the existing
        ``edge_features`` columns.
        """
        # Intended semantics (reference formulation):
        #   separation = decoys.receivers - decoys.senders - 1
        #   separation_cls = searchsorted(self.bins, separation, side='right') - 1
        # The code below negates the separation and flips the resulting index.
        # Mathematically that reproduces a right-sided search using a
        # left-sided searchsorted, but ONLY if self.bins holds the negated,
        # reversed bin edges.
        # NOTE(review): confirm how self.bins is constructed.

        # unsqueeze_(0) adds a leading row dim; presumably the external
        # `searchsorted` expects 2-D (rows, values) inputs and self.bins is
        # (1, num_bins) -- TODO confirm.
        separation = (decoys.senders - decoys.receivers +
                      1).float().unsqueeze_(0)
        # Left-sided search returns idx in [0, num_bins], so separation_cls is
        # in [-1, num_bins - 1].
        # NOTE(review): separation_cls == -1 (idx == num_bins) makes scatter_
        # below raise; confirm the bins cover every possible separation.
        separation_cls = (self.bins.numel() - 1) - searchsorted(
            self.bins, separation).squeeze_(0).long()

        # One-hot target: (num_edges, num_bins), default float dtype, on the
        # same device as the edge indices.
        separation_onehot = torch.zeros(decoys.num_edges,
                                        self.bins.numel(),
                                        device=decoys.senders.device)
        # Write 1.0 at column separation_cls[e] of each row e.
        separation_onehot.scatter_(value=1.,
                                   index=separation_cls.unsqueeze_(1),
                                   dim=1)

        decoys = decoys.evolve(edge_features=torch.cat(
            (decoys.edge_features, separation_onehot), dim=1), )

        return decoys
Пример #7
0
 def forward(self, decoys: DecoyBatch):
     """Append edge separation as one extra scalar edge feature.

     Separation is ``receivers - senders - 1`` cast to float, so an edge with
     ``receivers == senders + 1`` gets 0. Presumably this measures the gap
     between node (residue) indices -- TODO confirm. The value is added as
     the last column of ``edge_features``.
     """
     gap = (decoys.receivers - decoys.senders - 1).float()
     features = torch.cat([decoys.edge_features, gap.unsqueeze(1)], dim=1)
     return decoys.evolve(edge_features=features)