def loss(self, data: Graph, split="train"):
    """Compute the sampled-edge loss on one split of ``data``.

    The edges of ``data`` are filtered by the split's mask. Edge samples and
    labels are then drawn with ``sampling_edge_uniform``. The model runs on the
    sampled edges inside ``data.local_graph()``. The result is the ``self._loss``
    term plus ``self.penalty`` times the ``self._regularization`` term.

    Args:
        data: graph exposing ``edge_index``, ``edge_attr`` and the
            ``train_mask`` / ``val_mask`` / ``test_mask`` attributes.
        split: ``"train"`` selects ``data.train_mask`` and ``"val"`` selects
            ``data.val_mask``. Any other value falls back to ``data.test_mask``.

    Returns:
        The summed loss (``_loss`` output + penalty * regularization).
    """
    # NOTE(review): unrecognised split names silently use the test mask.
    split_mask = (
        data.train_mask if split == "train"
        else data.val_mask if split == "val"
        else data.test_mask
    )
    kept_edges = data.edge_index[:, split_mask]
    kept_types = data.edge_attr[split_mask]

    self.get_edge_set(kept_edges, kept_types)
    sampled = sampling_edge_uniform(
        kept_edges,
        kept_types,
        self.edge_set,
        self.sampling_rate,
        self.num_rels,
        label_smoothing=self.lbl_smooth,
        num_entities=self.num_entities,
    )
    batch_edges, batch_attr, samples, rels, labels = sampled

    with data.local_graph():
        # The sampled edges are installed only inside this scope for the forward pass.
        data.edge_index = batch_edges
        data.edge_attr = batch_attr
        node_embed, rel_embed = self.forward(data)

    # Map global node ids in `samples` to positions in the sorted unique-node list.
    unique_nodes, local_ids = torch.unique(samples, sorted=True, return_inverse=True)
    # NOTE(review): `.any()` passes if a single position matches; `.all()` may be intended — confirm.
    assert (self.cache_index == unique_nodes).any()

    src_embed = node_embed[local_ids[0]]
    dst_embed = node_embed[local_ids[1]]
    score_loss = self._loss(src_embed, dst_embed, rel_embed[rels], labels)
    reg_loss = self.penalty * self._regularization([self.emb(unique_nodes), rel_embed])
    return score_loss + reg_loss
def forward(self, graph: Graph) -> torch.Tensor:
    """Encode ``graph.x`` with an input GCN, ``self.unet`` and an output GCN.

    When ``self.improved`` is set, one self-loop per node is appended to the
    edge list. The ``"unet_improved"`` flag stored on the graph ensures this
    happens only once per graph. Edge weights are row-normalised via
    ``graph.row_norm()`` on every call. During training, edges may be dropped
    with probability ``self.adj_dropout`` inside ``graph.local_graph()``.

    Args:
        graph: input graph; ``graph.x`` holds the node features.

    Returns:
        Output of ``self.out_gcn`` applied to the last U-Net representation.
    """
    x = graph.x
    if self.improved and not hasattr(graph, "unet_improved"):
        src, dst = graph.edge_index
        # The same id range is appended to both endpoints, giving (i, i) self-loops.
        loop_ids = torch.arange(0, x.shape[0], device=x.device)
        graph.edge_index = (
            torch.cat([src, loop_ids], dim=0),
            torch.cat([dst, loop_ids], dim=0),
        )
        graph["unet_improved"] = True
    graph.row_norm()

    with graph.local_graph():
        if self.training and self.adj_dropout > 0:
            graph.edge_index, graph.edge_weight = dropout_adj(
                graph.edge_index, graph.edge_weight, self.adj_dropout
            )
        hidden = F.dropout(x, p=self.n_dropout, training=self.training)
        hidden = self.act(self.in_gcn(graph, hidden))
        # self.unet returns a list of representations; only the final one is used.
        hidden = self.unet(graph, hidden)[-1]
        hidden = F.dropout(hidden, p=self.n_dropout, training=self.training)
        return self.out_gcn(graph, hidden)
def loss(self, data: Graph, scoring):
    """Compute the sampled-edge loss over all edges of ``data``.

    Edge samples and labels are drawn with ``sampling_edge_uniform``. The model
    runs on the sampled edges inside ``data.local_graph()``. The result is the
    ``self._loss`` term (given ``scoring``) plus ``self.penalty`` times the
    ``self._regularization`` term.

    Args:
        data: graph whose ``edge_index`` unpacks into ``(row, col)`` and whose
            ``edge_attr`` holds the edge types.
        scoring: forwarded unchanged to ``self._loss``.

    Returns:
        The summed loss (``_loss`` output + penalty * regularization).
    """
    src, dst = data.edge_index
    relation_types = data.edge_attr
    # get_edge_set receives the endpoints stacked into a single 2 x E tensor.
    self.get_edge_set(torch.stack([src, dst]), relation_types)

    sampled = sampling_edge_uniform(
        (src, dst),
        relation_types,
        self.edge_set,
        self.sampling_rate,
        self.num_rels,
        label_smoothing=self.lbl_smooth,
        num_entities=self.num_entities,
    )
    batch_edges, batch_attr, samples, rels, labels = sampled

    with data.local_graph():
        # The sampled edges are installed only inside this scope for the forward pass.
        data.edge_index = batch_edges
        data.edge_attr = batch_attr
        node_embed, rel_embed = self.forward(data)

    # Map global node ids in `samples` to positions in the sorted unique-node list.
    unique_nodes, local_ids = torch.unique(samples, sorted=True, return_inverse=True)
    # NOTE(review): `.any()` passes if a single position matches; `.all()` may be intended — confirm.
    assert (self.cache_index == unique_nodes).any()

    src_embed = node_embed[local_ids[0]]
    dst_embed = node_embed[local_ids[1]]
    score_loss = self._loss(src_embed, dst_embed, rel_embed[rels], labels, scoring)
    reg_loss = self.penalty * self._regularization([self.emb(unique_nodes), rel_embed])
    return score_loss + reg_loss
def prop(
    self,
    graph: Graph,
    x: torch.Tensor,
    drop_feature_rate: float = 0.0,
    drop_edge_rate: float = 0.0,
):
    """Run ``self.forward`` on a perturbed view of ``(graph, x)``.

    Args:
        graph: input graph; edge dropping is applied inside ``graph.local_graph()``.
        x: node features, passed through ``self.drop_feature`` first.
        drop_feature_rate: rate handed to ``self.drop_feature``.
        drop_edge_rate: rate handed to ``self.drop_adj``.

    Returns:
        Whatever ``self.forward`` returns for the perturbed view.
    """
    perturbed_x = self.drop_feature(x, drop_feature_rate)
    with graph.local_graph():
        # local_graph presumably restores the original edges on exit — confirm.
        perturbed_graph = self.drop_adj(graph, drop_edge_rate)
        return self.forward(perturbed_graph, perturbed_x)
def prop(
    self,
    graph: Graph,
    x: torch.Tensor,
    drop_feature_rate: float = 0.0,
    drop_edge_rate: float = 0.0,
):
    """Run ``self.model.forward`` on a perturbed view of ``(graph, x)``.

    Args:
        graph: input graph; its ``edge_index`` / ``edge_weight`` are replaced by
            the output of ``dropout_adj`` inside ``graph.local_graph()``.
        x: node features, passed through ``dropout_features`` first.
        drop_feature_rate: rate handed to ``dropout_features``.
        drop_edge_rate: rate handed to ``dropout_adj``.

    Returns:
        Whatever ``self.model.forward`` returns for the perturbed view.
    """
    features = dropout_features(x, drop_feature_rate)
    with graph.local_graph():
        # local_graph presumably restores the original edges on exit — confirm.
        dropped = dropout_adj(graph.edge_index, graph.edge_weight, drop_edge_rate)
        graph.edge_index, graph.edge_weight = dropped
        return self.model.forward(graph, features)
def _add_undirected_graph_positional_embedding(g: Graph, hidden_size, retry=10):
    """Attach spectral positional features to ``g`` as ``g.pos_undirected``.

    The features come from ``eigen_decomposision`` applied to the CSR matrix
    produced after ``g.sym_norm()``. That matrix is presumably the symmetrically
    normalised adjacency D^{-1/2} A D^{-1/2`}; confirm against ``sym_norm``.
    It has the same eigenvectors as the normalised Laplacian
    I - D^{-1/2} A D^{-1/2}. So this is a generalisation of the positional
    encoding in "Attention Is All You Need": the Laplacian eigenvectors of a
    path graph are cosine/sine functions. See section 2.4 of
    http://www.cs.yale.edu/homes/spielman/561/2009/lect02-09.pdf.

    Args:
        g: graph to annotate; it is also the return value.
        hidden_size: width of the positional features; at most
            ``min(num_nodes - 2, hidden_size)`` eigenvectors are requested.
        retry: forwarded to ``eigen_decomposision``.

    Returns:
        ``g``, with ``g.pos_undirected`` set to the float features.
    """
    num_nodes = g.num_nodes
    with g.local_graph():
        # Normalisation is presumably undone when this scope exits; only the CSR copy is kept.
        g.sym_norm()
        norm_adj = g.to_scipy_csr()
    # NOTE(review): this is <= 0 for graphs with <= 2 nodes; handling is up to eigen_decomposision.
    num_vectors = min(num_nodes - 2, hidden_size)
    features = eigen_decomposision(num_nodes, num_vectors, norm_adj, hidden_size, retry)
    g.pos_undirected = features.float()
    return g
def forward(self, graph: Graph) -> torch.Tensor:
    """Encode ``graph.x`` with an input GCN, ``self.unet`` and an output GCN.

    When ``self.improved`` is set, one self-loop per node is appended to the
    edge list. The ``"unet_improved"`` flag stored on the graph ensures this
    happens only once per graph. Edge weights are row-normalised via
    ``graph.row_norm()`` on every call. During training, edges may be dropped
    with probability ``self.adj_dropout`` inside ``graph.local_graph()``.

    Args:
        graph: input graph; ``graph.x`` holds the node features.

    Returns:
        Output of ``self.out_gcn`` applied to the last U-Net representation.
    """
    x = graph.x
    if self.improved and not hasattr(graph, "unet_improved"):
        # Unpack instead of passing graph.edge_index straight to torch.cat: torch.cat
        # rejects a (row, col) tuple, while unpacking works for a tuple or a 2 x E tensor.
        row, col = graph.edge_index
        # Build the ids directly on x's device rather than on the CPU followed by a copy.
        node_ids = torch.arange(0, x.shape[0], device=x.device)
        row = torch.cat([row, node_ids], dim=0)
        col = torch.cat([col, node_ids], dim=0)
        # Store a stacked 2 x (E + N) tensor, the same form the original produced.
        graph.edge_index = torch.stack([row, col], dim=0)
        graph["unet_improved"] = True
    graph.row_norm()

    with graph.local_graph():
        if self.training and self.adj_dropout > 0:
            graph.edge_index, graph.edge_weight = dropout_adj(
                graph.edge_index, graph.edge_weight, self.adj_dropout
            )
        x = F.dropout(x, p=self.n_dropout, training=self.training)
        h = self.in_gcn(graph, x)
        h = self.act(h)
        # self.unet returns a list of representations; only the final one is used.
        h_list = self.unet(graph, h)
        h = h_list[-1]
        h = F.dropout(h, p=self.n_dropout, training=self.training)
        return self.out_gcn(graph, h)