def forward(self, graph: Graph) -> torch.Tensor:
    """Graph-U-Net forward pass.

    Optionally augments the graph with self loops (once per graph object),
    applies edge/feature dropout, then runs input GCN -> U-Net -> output GCN.
    """
    features = graph.x
    if self.improved and not hasattr(graph, "unet_improved"):
        # Append one self loop per node; the "unet_improved" flag makes
        # this a one-time, idempotent mutation of the graph.
        loop_idx = torch.arange(0, features.shape[0], device=features.device)
        src, dst = graph.edge_index
        graph.edge_index = (
            torch.cat([src, loop_idx], dim=0),
            torch.cat([dst, loop_idx], dim=0),
        )
        graph["unet_improved"] = True
    # NOTE(review): source was collapsed to one line; row_norm() is assumed
    # to run on every call (not only inside the `if`) — confirm.
    graph.row_norm()
    with graph.local_graph():
        if self.training and self.adj_dropout > 0:
            graph.edge_index, graph.edge_weight = dropout_adj(
                graph.edge_index, graph.edge_weight, self.adj_dropout
            )
        features = F.dropout(features, p=self.n_dropout, training=self.training)
        hidden = self.act(self.in_gcn(graph, features))
        # The U-Net returns per-depth representations; keep the last one.
        hidden = self.unet(graph, hidden)[-1]
        hidden = F.dropout(hidden, p=self.n_dropout, training=self.training)
        return self.out_gcn(graph, hidden)
def loss(self, data: Graph, split="train"):
    """Sampled link-prediction loss over the edges selected by *split*.

    Returns the scoring loss plus the embedding regularization penalty.
    """
    if split == "train":
        mask = data.train_mask
    elif split == "val":
        mask = data.val_mask
    else:
        mask = data.test_mask
    edge_index = data.edge_index[:, mask]
    edge_types = data.edge_attr[mask]
    self.get_edge_set(edge_index, edge_types)
    batch_edges, batch_attr, samples, rels, labels = sampling_edge_uniform(
        edge_index,
        edge_types,
        self.edge_set,
        self.sampling_rate,
        self.num_rels,
        label_smoothing=self.lbl_smooth,
        num_entities=self.num_entities,
    )
    # Temporarily swap in the sampled batch, encode, then restore.
    with data.local_graph():
        data.edge_index = batch_edges
        data.edge_attr = batch_attr
        node_embed, rel_embed = self.forward(data)
    sampled_nodes, reindexed_edges = torch.unique(
        samples, sorted=True, return_inverse=True
    )
    # NOTE(review): `.any()` only requires one matching entry — confirm
    # `.all()` was not intended here.
    assert (self.cache_index == sampled_nodes).any()
    score_loss = self._loss(
        node_embed[reindexed_edges[0]],
        node_embed[reindexed_edges[1]],
        rel_embed[rels],
        labels,
    )
    reg_loss = self.penalty * self._regularization(
        [self.emb(sampled_nodes), rel_embed]
    )
    return score_loss + reg_loss
def loss(self, data: Graph, scoring):
    """Sampled link-prediction loss using an explicit *scoring* function.

    Same pipeline as the split-based variant, but samples over the whole
    edge set and forwards *scoring* to ``self._loss``.
    """
    src, dst = data.edge_index
    edge_types = data.edge_attr
    self.get_edge_set(torch.stack([src, dst]), edge_types)
    batch_edges, batch_attr, samples, rels, labels = sampling_edge_uniform(
        (src, dst),
        edge_types,
        self.edge_set,
        self.sampling_rate,
        self.num_rels,
        label_smoothing=self.lbl_smooth,
        num_entities=self.num_entities,
    )
    # Temporarily swap in the sampled batch, encode, then restore.
    with data.local_graph():
        data.edge_index = batch_edges
        data.edge_attr = batch_attr
        node_embed, rel_embed = self.forward(data)
    sampled_nodes, reindexed_edges = torch.unique(
        samples, sorted=True, return_inverse=True
    )
    # NOTE(review): `.any()` only requires one matching entry — confirm
    # `.all()` was not intended here.
    assert (self.cache_index == sampled_nodes).any()
    score_loss = self._loss(
        node_embed[reindexed_edges[0]],
        node_embed[reindexed_edges[1]],
        rel_embed[rels],
        labels,
        scoring,
    )
    reg_loss = self.penalty * self._regularization(
        [self.emb(sampled_nodes), rel_embed]
    )
    return score_loss + reg_loss
def forward(self, graph: Graph) -> torch.Tensor:
    """Graph-U-Net forward pass (stacked self-loop variant).

    Adds self loops once per graph object, applies edge/feature dropout,
    then runs input GCN -> U-Net -> output GCN.
    """
    x = graph.x
    if self.improved and not hasattr(graph, "unet_improved"):
        # Build a 2 x N self-loop block and append it to edge_index.
        idx = torch.arange(0, x.shape[0])
        graph.edge_index = torch.cat(
            [graph.edge_index, torch.stack((idx, idx)).to(x.device)], dim=1
        )
        graph["unet_improved"] = True
    # NOTE(review): source was collapsed to one line; row_norm() is assumed
    # to run on every call (not only inside the `if`) — confirm.
    graph.row_norm()
    with graph.local_graph():
        if self.training and self.adj_dropout > 0:
            graph.edge_index, graph.edge_weight = dropout_adj(
                graph.edge_index, graph.edge_weight, self.adj_dropout
            )
        x = F.dropout(x, p=self.n_dropout, training=self.training)
        out = self.in_gcn(graph, x)
        out = self.act(out)
        # The U-Net returns per-depth representations; keep the last one.
        out = self.unet(graph, out)[-1]
        out = F.dropout(out, p=self.n_dropout, training=self.training)
        return self.out_gcn(graph, out)
def read_triplet_data(folder):
    """Read knowledge-graph triples from OpenKE-style id files.

    Each of train2id.txt / valid2id.txt / test2id.txt starts with a count
    line, followed by lines of three integer ids ("head tail relation",
    per the parsing below).

    Returns:
        (data, triples, train_start_idx, valid_start_idx, test_start_idx,
         num_entities, num_relations) where ``data`` is a Graph carrying
        edge_index / edge_attr and boolean train/val/test edge masks, and
        ``triples`` is a list of (head, relation, tail) tuples.
    """
    filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
    count = 0
    edge_index = []
    edge_attr = []
    count_list = []
    triples = []
    num_entities = 0
    num_relations = 0
    entity_dic = {}
    relation_dic = {}
    # Initialize so the function cannot raise UnboundLocalError if a file
    # name ever fails the substring checks below.
    train_start_idx = valid_start_idx = test_start_idx = 0
    for filename in filenames:
        with open(osp.join(folder, filename), "r") as f:
            _ = int(f.readline().strip())  # declared count line; not trusted
            if "train" in filename:
                train_start_idx = len(triples)
            elif "valid" in filename:
                valid_start_idx = len(triples)
            elif "test" in filename:
                test_start_idx = len(triples)
            for line in f:
                items = line.strip().split()
                edge_index.append([int(items[0]), int(items[1])])
                edge_attr.append(int(items[2]))
                # Stored as (head, relation, tail).
                triples.append((int(items[0]), int(items[2]), int(items[1])))
                if items[0] not in entity_dic:
                    entity_dic[items[0]] = num_entities
                    num_entities += 1
                if items[1] not in entity_dic:
                    entity_dic[items[1]] = num_entities
                    num_entities += 1
                if items[2] not in relation_dic:
                    relation_dic[items[2]] = num_relations
                    num_relations += 1
                count += 1
        count_list.append(count)

    edge_index = torch.LongTensor(edge_index).t()
    edge_attr = torch.LongTensor(edge_attr)
    data = Graph()
    data.edge_index = edge_index
    data.edge_attr = edge_attr

    def generate_mask(start, end):
        # torch.zeros(..., dtype=torch.bool) replaces the legacy
        # torch.BoolTensor(count), which allocated *uninitialized* memory
        # that then had to be cleared by hand.
        mask = torch.zeros(count, dtype=torch.bool)
        mask[start:end] = True
        return mask

    data.train_mask = generate_mask(0, count_list[0])
    data.val_mask = generate_mask(count_list[0], count_list[1])
    data.test_mask = generate_mask(count_list[1], count_list[2])
    return (
        data,
        triples,
        train_start_idx,
        valid_start_idx,
        test_start_idx,
        num_entities,
        num_relations,
    )
def prop(
    self,
    graph: Graph,
    x: torch.Tensor,
    drop_feature_rate: float = 0.0,
    drop_edge_rate: float = 0.0,
):
    """Encode *x* on *graph* after random feature and edge dropout.

    The edge perturbation is confined to the local_graph scope, so the
    caller's graph is left untouched.
    """
    perturbed = dropout_features(x, drop_feature_rate)
    with graph.local_graph():
        graph.edge_index, graph.edge_weight = dropout_adj(
            graph.edge_index, graph.edge_weight, drop_edge_rate
        )
        return self.model.forward(graph, perturbed)
def drop_adj(self, graph: Graph, drop_rate: float = 0.5):
    """Randomly drop edges (and their weights) with probability *drop_rate*.

    A no-op in eval mode. Mutates *graph* in place and returns it.

    Raises:
        ValueError: if *drop_rate* is outside [0, 1].
    """
    if drop_rate < 0.0 or drop_rate > 1.0:
        raise ValueError("Dropout probability has to be between 0 and 1, "
                         "but got {}".format(drop_rate))
    if not self.training:
        return graph

    edge_index = graph.edge_index
    edge_weight = graph.edge_weight
    # Build the keep-mask on the same device as the edges: the original
    # allocated it on the CPU, which fails when masking CUDA tensors.
    keep_prob = torch.full(
        (graph.num_edges,), 1 - drop_rate,
        dtype=torch.float, device=edge_index.device,
    )
    mask = torch.bernoulli(keep_prob).to(torch.bool)
    graph.edge_index = edge_index[:, mask]
    graph.edge_weight = edge_weight[mask]
    return graph
def _construct_propagation_matrix(self, sample_adj, sample_id, num_neighbors):
    """Build the propagation Graph for a sampled adjacency block.

    Prepends a self connection for each of the first ``n_rows`` sampled
    nodes, then rescales every edge weight by the node degree divided by
    the number of sampled neighbors (presumably an importance-sampling
    correction — confirm against the sampler).
    """
    src, dst = sample_adj.edge_index
    weight = sample_adj.edge_weight
    # Add self connections ahead of the sampled edges.
    n_rows = src.max() + 1
    loops = torch.arange(0, n_rows).long()
    src = torch.cat([loops, src], dim=0)
    dst = torch.cat([loops, dst], dim=0)
    weight = torch.cat([self.diag[sample_id[:n_rows]], weight], dim=0)
    weight = weight * self.degree[sample_id[src]] / num_neighbors
    prop_graph = Graph()
    prop_graph.edge_index = torch.stack([src, dst])
    prop_graph.edge_weight = weight
    return prop_graph
def read_gtn_data(self, folder):
    """Load a GTN-style heterogeneous dataset into ``self.data``.

    Reads edges.pkl (4 sparse adjacency matrices), labels.pkl
    (train/valid/test [node, target] arrays) and node_features.pkl from
    *folder*, and assembles a Graph with features, node types, a merged
    edge_index, per-type adjacency list ``adj`` and label tensors.
    """
    # NOTE(review): pickle.load executes arbitrary code from the file —
    # only use with trusted dataset archives.
    # Fix: the originals leaked file handles via pickle.load(open(...)).
    with open(osp.join(folder, "edges.pkl"), "rb") as f:
        edges = pickle.load(f)
    with open(osp.join(folder, "labels.pkl"), "rb") as f:
        labels = pickle.load(f)
    with open(osp.join(folder, "node_features.pkl"), "rb") as f:
        node_features = pickle.load(f)

    data = Graph()
    data.x = torch.from_numpy(node_features).type(torch.FloatTensor)

    num_nodes = edges[0].shape[0]
    node_type = np.zeros((num_nodes), dtype=int)
    assert len(edges) == 4
    assert len(edges[0].nonzero()) == 2
    # The four matrices connect node types in (src, dst) pairs:
    # 0-1, 1-0, 0-2, 2-0; label each endpoint accordingly.
    node_type[edges[0].nonzero()[0]] = 0
    node_type[edges[0].nonzero()[1]] = 1
    node_type[edges[1].nonzero()[0]] = 1
    node_type[edges[1].nonzero()[1]] = 0
    node_type[edges[2].nonzero()[0]] = 0
    node_type[edges[2].nonzero()[1]] = 2
    node_type[edges[3].nonzero()[0]] = 2
    node_type[edges[3].nonzero()[1]] = 0
    data.pos = torch.from_numpy(node_type)

    # Single pass over the matrices: the original computed each nonzero
    # pattern twice (once for edge_index, once for adj).
    edge_list = []
    A = []
    for edge in edges:
        rows, cols = edge.nonzero()[0], edge.nonzero()[1]
        edge_tmp = torch.from_numpy(np.vstack((rows, cols))).type(torch.LongTensor)
        edge_list.append(edge_tmp)
        value_tmp = torch.ones(edge_tmp.shape[1]).type(torch.FloatTensor)
        A.append((edge_tmp, value_tmp))
    data.edge_index = torch.cat(edge_list, 1)
    # Append an identity (self-loop) adjacency as the final relation.
    eye_tmp = torch.stack(
        (torch.arange(0, num_nodes), torch.arange(0, num_nodes))
    ).type(torch.LongTensor)
    A.append((eye_tmp, torch.ones(num_nodes).type(torch.FloatTensor)))
    data.adj = A

    def _column(label_block, col):
        # Labels arrive as [node_id, target] rows; pull one column out.
        return torch.from_numpy(np.array(label_block)[:, col]).type(torch.LongTensor)

    data.train_node = _column(labels[0], 0)
    data.train_target = _column(labels[0], 1)
    data.valid_node = _column(labels[1], 0)
    data.valid_target = _column(labels[1], 1)
    data.test_node = _column(labels[2], 0)
    data.test_target = _column(labels[2], 1)

    # Dense label vector over all nodes; unlabeled nodes stay 0.
    y = np.zeros((num_nodes), dtype=int)
    x_index = torch.cat((data.train_node, data.valid_node, data.test_node))
    y_index = torch.cat((data.train_target, data.valid_target, data.test_target))
    y[x_index.numpy()] = y_index.numpy()
    data.y = torch.from_numpy(y)
    self.data = data