예제 #1
0
 def predict(self, edge_index, edge_types):
     """Evaluate link prediction on the given edges with the current model.

     Every entity id in ``range(self.num_entities)`` is embedded with
     ``self.emb``. The embeddings are encoded by ``self.model`` together with
     the edges, and the resulting node and relation embeddings are scored by
     ``cal_mrr`` under the "raw" protocol.

     Args:
         edge_index: Edges to encode and score. Its ``device`` decides where
             the entity index tensor is created.
         edge_types: Relation type of each edge in ``edge_index``.

     Returns:
         The ``(mrr, hits)`` pair returned by ``cal_mrr``, with hits requested
         at k = 1, 3 and 10.
     """
     # Create the index tensor directly on the target device rather than
     # allocating it on the CPU and copying it over with ``.to()``.
     indices = torch.arange(self.num_entities, device=edge_index.device)
     x = self.emb(indices)
     # NOTE: edges are encoded exactly as given; no reverse edges are added.
     node_embed, rel_embed = self.model(x, edge_index, edge_types)
     mrr, hits = cal_mrr(
         node_embed,
         rel_embed,
         edge_index,
         edge_types,
         scoring=self.scoring,
         protocol="raw",
         batch_size=500,
         hits=[1, 3, 10],
     )
     return mrr, hits
예제 #2
0
파일: rgcn.py 프로젝트: xssstory/cogdl
 def predict(self, edge_index, edge_type):
     """Evaluate link prediction on the given edges with the current model.

     Every node id in ``range(self.num_nodes)`` is embedded with ``self.emb``
     and encoded by ``self.model`` together with the edges. The encoder
     output is then scored by ``cal_mrr`` under the "raw" protocol, with
     ``self.rel_weight.weight`` as the relation embedding table.

     Args:
         edge_index: Edges to encode and score. Its ``device`` decides where
             the node index tensor is created.
         edge_type: Relation type of each edge in ``edge_index``.

     Returns:
         The ``(mrr, hits)`` pair returned by ``cal_mrr``, with hits requested
         at k = 1, 3 and 10.
     """
     # Create the index tensor directly on the target device rather than
     # allocating it on the CPU and copying it over with ``.to()``.
     indices = torch.arange(self.num_nodes, device=edge_index.device)
     x = self.emb(indices)
     output = self.model(x, edge_index, edge_type)
     mrr, hits = cal_mrr(
         output,
         self.rel_weight.weight,
         edge_index,
         edge_type,
         scoring=self.scoring,
         protocol="raw",
         batch_size=500,
         hits=[1, 3, 10],
     )
     return mrr, hits
예제 #3
0
파일: rgcn.py 프로젝트: huaxz1986/cogdl
 def predict(self, graph):
     """Evaluate link prediction on the edges of ``graph`` with the current model.

     Every node id in ``range(self.num_nodes)`` is embedded with ``self.emb``
     and encoded by ``self.model`` together with ``graph``. The encoder output
     is then scored by ``cal_mrr`` under the "raw" protocol, with
     ``self.rel_weight.weight`` as the relation embedding table.

     Args:
         graph: Graph object exposing ``edge_index`` (edges to score) and
             ``edge_attr``. ``edge_attr`` is passed to ``cal_mrr`` in the
             edge-type position.

     Returns:
         The ``(mrr, hits)`` pair returned by ``cal_mrr``, with hits requested
         at k = 1, 3 and 10.
     """
     # The device of the first parameter is used for the index tensor.
     # NOTE(review): this assumes all parameters live on a single device.
     device = next(self.parameters()).device
     # Create the index tensor directly on that device instead of copying it.
     indices = torch.arange(self.num_nodes, device=device)
     x = self.emb(indices)
     output = self.model(graph, x)
     mrr, hits = cal_mrr(
         output,
         self.rel_weight.weight,
         graph.edge_index,
         graph.edge_attr,
         scoring=self.scoring,
         protocol="raw",
         batch_size=500,
         hits=[1, 3, 10],
     )
     return mrr, hits
예제 #4
0
파일: compgcn.py 프로젝트: huaxz1986/cogdl
 def predict(self, graph):
     """Evaluate link prediction on the edges of ``graph`` with the current model.

     Every entity id in ``range(self.num_entities)`` is embedded with
     ``self.emb`` and encoded by ``self.model`` together with ``graph``. The
     resulting node and relation embeddings are scored by ``cal_mrr`` under
     the "raw" protocol.

     Args:
         graph: Graph object exposing ``edge_index`` (edges to score) and
             ``edge_attr``. ``edge_attr`` is passed to ``cal_mrr`` as the edge
             types.

     Returns:
         The ``(mrr, hits)`` pair returned by ``cal_mrr``, with hits requested
         at k = 1, 3 and 10.
     """
     # The device of the first parameter is used for the index tensor.
     # NOTE(review): this assumes all parameters live on a single device.
     device = next(self.parameters()).device
     # Create the index tensor directly on that device instead of copying it.
     indices = torch.arange(self.num_entities, device=device)
     x = self.emb(indices)
     # NOTE: edges are encoded exactly as stored in ``graph``; no reverse
     # edges are added here.
     node_embed, rel_embed = self.model(graph, x)
     edge_index, edge_types = graph.edge_index, graph.edge_attr
     mrr, hits = cal_mrr(
         node_embed,
         rel_embed,
         edge_index,
         edge_types,
         scoring=self.scoring,
         protocol="raw",
         batch_size=500,
         hits=[1, 3, 10],
     )
     return mrr, hits