def forward(self, decoys: DecoyBatch) -> DecoyBatch:
    """Embed residue identities and prepend them to the node features.

    The parent class's ``forward`` is applied to ``decoys.residues``; its
    output is concatenated in front of ``decoys.node_features`` along dim 1.
    The returned batch has ``residues`` cleared (set to ``None``) because the
    information now lives in ``node_features``.
    """
    parent = super(ResidueEmbedding, self)
    embedded = parent.forward(decoys.residues)
    features = torch.cat((embedded, decoys.node_features), dim=1)
    return decoys.evolve(residues=None, node_features=features)
def __call__(self, decoys: DecoyBatch):
    """Make every edge bidirectional by appending its reversed copy.

    All reversed edges (receiver -> sender) are appended after all original
    edges. Per-edge tensors (``edge_features`` along dim 0 and
    ``edge_index_by_graph``) are therefore tiled twice in the same order,
    and each graph's edge count is doubled.
    """
    fwd_src, fwd_dst = decoys.senders, decoys.receivers
    return decoys.evolve(
        senders=torch.cat((fwd_src, fwd_dst), dim=0),
        receivers=torch.cat((fwd_dst, fwd_src), dim=0),
        edge_features=decoys.edge_features.repeat(2, 1),
        num_edges_by_graph=decoys.num_edges_by_graph * 2,
        edge_index_by_graph=decoys.edge_index_by_graph.repeat(2),
    )
def forward(self, decoys: DecoyBatch):
    """Use raw distances as a single scalar edge feature.

    ``decoys.distances`` gets a new axis at dim 1 and is concatenated in
    front of ``decoys.edge_features`` along dim 1; ``distances`` is then
    cleared from the returned batch.
    """
    distance_column = decoys.distances.unsqueeze(1)
    edge_features = torch.cat((distance_column, decoys.edge_features), dim=1)
    return decoys.evolve(distances=None, edge_features=edge_features)
def forward(self, decoys: DecoyBatch):
    """Expand distances with Gaussian RBF kernels and prepend them.

    Each distance ``d`` becomes ``exp(-(d - c) ** 2)`` for every center ``c``
    in ``self.rbf_centers``. The centers are configured outside this method;
    the original author described them as equally spaced. Note that this
    kernel has variance 1/2 in the ``exp(-x**2 / (2 * sigma**2))``
    parametrization, not unit variance. The RBF columns are placed in front
    of ``decoys.edge_features`` along dim 1 and ``distances`` is cleared.
    """
    offsets = decoys.distances.unsqueeze(1) - self.rbf_centers.unsqueeze(0)
    distances_rbf = torch.exp(-(offsets ** 2))
    edge_features = torch.cat((distances_rbf, decoys.edge_features), dim=1)
    return decoys.evolve(distances=None, edge_features=edge_features)
def forward(self, decoys: DecoyBatch):
    """Run the full network and strip edge-level data from the result.

    Applies ``preprocessing``, ``encoder``, ``layers`` and ``readout`` in
    that order. The returned batch keeps ``node_features`` and
    ``global_features``, and clears the connectivity and edge data
    (``num_edges_by_graph``, ``edge_index_by_graph``, ``edge_features``,
    ``senders``, ``receivers``).
    """
    for stage in (self.preprocessing, self.encoder, self.layers, self.readout):
        decoys = stage(decoys)
    return decoys.evolve(
        node_features=decoys.node_features,
        num_edges_by_graph=None,
        edge_index_by_graph=None,
        edge_features=None,
        global_features=decoys.global_features,
        senders=None,
        receivers=None,
    )
def forward(self, decoys: DecoyBatch):
    """Append a one-hot bucketing of edge separation to the edge features.

    The value searched is ``senders - receivers + 1``, i.e. the negation of
    ``receivers - senders - 1``. Its bucket index is one-hot encoded into
    ``self.bins.numel()`` columns, which are concatenated after the existing
    ``decoys.edge_features`` along dim 1.

    NOTE(review): the index is computed as
    ``numel() - 1 - searchsorted(bins, -separation)``. If ``self.bins``
    holds the original bin edges negated and sorted ascending, and the
    external ``searchsorted`` is left-sided, this equals
    ``searchsorted(edges, separation, side='right') - 1`` on the
    un-negated edges. TODO confirm how ``self.bins`` is built.
    """
    # Negated separation. The unsqueeze_(0)/squeeze_(0) pair suggests the
    # external ``searchsorted`` expects batched (rows, n) inputs — TODO confirm.
    separation = (decoys.senders - decoys.receivers + 1).float().unsqueeze_(0)
    # Flip the search index back into the bucket index (see docstring note).
    separation_cls = (self.bins.numel() - 1) - searchsorted(
        self.bins, separation).squeeze_(0).long()
    # One row per edge, one column per bin. torch.zeros uses the default
    # dtype; presumably this matches edge_features' dtype — verify.
    separation_onehot = torch.zeros(decoys.num_edges,
                                    self.bins.numel(),
                                    device=decoys.senders.device)
    # Set a single 1 per row. Indices are not clamped here, so a bucket index
    # outside [0, numel()) would make scatter_ fail.
    separation_onehot.scatter_(value=1.,
                               index=separation_cls.unsqueeze_(1),
                               dim=1)
    decoys = decoys.evolve(edge_features=torch.cat(
        (decoys.edge_features, separation_onehot), dim=1), )
    return decoys
def forward(self, decoys: DecoyBatch):
    """Append the raw edge separation as one scalar edge feature.

    For each edge the value is ``receivers - senders - 1``, cast to float and
    concatenated after ``decoys.edge_features`` along dim 1.
    """
    gap = (decoys.receivers - decoys.senders - 1).float()
    new_features = torch.cat((decoys.edge_features, gap.unsqueeze(1)), dim=1)
    return decoys.evolve(edge_features=new_features)