def forward(self, z_n, z_g, graph: dgl.DGLGraph) -> Tensor:
    '''
    Args:
        z_n: Tensor of shape [n_nodes, z_dim], node (local) representations.
        z_g: Tensor of shape [n_graphs, z_dim], graph (global) representations.
        graph: Batched DGLGraph containing n_graphs graphs.
    '''
    device = z_g.device
    num_graphs = z_g.shape[0]
    num_nodes = z_n.shape[0]

    # Build a [n_nodes, n_graphs] mask marking which graph each node belongs to.
    node_offsets = torch.cat([
        torch.zeros(1, dtype=torch.long, device=device),
        torch.cumsum(graph.batch_num_nodes(), dim=0)
    ])
    pos_mask = torch.zeros((num_nodes, num_graphs), device=device)
    for graph_idx in range(num_graphs):
        pos_mask[node_offsets[graph_idx]:node_offsets[graph_idx + 1], graph_idx] = 1.
    neg_mask = 1 - pos_mask

    # Discriminator scores between every node and every graph representation.
    d_prime = torch.matmul(z_n, z_g.t())

    E_pos = get_expectation(d_prime * pos_mask, positive=True).sum()
    E_pos = E_pos / num_nodes
    E_neg = get_expectation(d_prime * neg_mask, positive=False).sum()
    E_neg = E_neg / (num_nodes * (num_graphs - 1))
    return E_neg - E_pos
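# The `get_expectation` helper used above is not shown in this snippet. Below
# is a minimal sketch of what it might look like, assuming the Jensen-Shannon
# MI estimator used by Deep InfoMax / InfoGraph style losses; the actual
# implementation in this codebase may differ.
import math

import torch
import torch.nn.functional as F


def get_expectation(masked_d_prime: torch.Tensor, positive: bool = True) -> torch.Tensor:
    """JSD-style expectation terms over masked discriminator scores.

    masked_d_prime: [n_nodes, n_graphs] scores, pre-multiplied by the
    positive or negative mask so irrelevant pairs contribute zero.
    """
    log_2 = math.log(2.0)
    if positive:
        # E_P[log(2) - softplus(-d')]
        return log_2 - F.softplus(-masked_d_prime)
    # E_N[softplus(-d') + d' - log(2)]
    return F.softplus(-masked_d_prime) + masked_d_prime - log_2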
def forward(self, bg: dgl.DGLGraph, spatial_features: torch.Tensor,
            temporal_features: torch.Tensor, external_features: torch.Tensor):
    """
    Get predictions.

    :param bg: batched graphs with `node_num` nodes in total,
        consisting of `batch_size` disconnected subgraphs
    :param spatial_features: shape [node_num, F_1]
    :param temporal_features: shape [node_num, F_2, T]
    :param external_features: shape [batch_size, F_3]
    :return: a tensor of shape [batch_size] with the prediction for each graph
    """
    if get_attribute("use_SBlock"):
        # s_out of shape [node_num, 10]
        s_out = self.spatial_gcn(bg, self.spatial_embedding(spatial_features))
    else:
        # remove spatial layer
        s_out = self.replace_spatial_gcn(spatial_features)

    if get_attribute("use_STBlock"):
        # temporal_embeddings of shape [node_num, 20, T_in]
        temporal_embeddings = self.temporal_embedding(bg, temporal_features)
    else:
        # remove temporal layer
        temporal_embeddings = torch.transpose(
            self.replace_temporal_layer(torch.transpose(temporal_features, -1, -2)),
            -1, -2)

    # t_out of shape [1, node_num, 10]
    # _, (t_out, _) = self.temporal_agg(torch.transpose(temporal_embeddings, -1, -2))
    t_out = self.temporal_agg(temporal_embeddings)
    t_out.squeeze_()

    if get_attribute("use_Embedding"):
        e_out = self.external_embedding(external_features)
    else:
        # remove external embedding layer
        e_out = external_features

    # `batch_num_nodes` is a method in recent DGL and a property in
    # older releases.
    try:
        nums_nodes, idx = bg.batch_num_nodes(), 0
    except TypeError:
        nums_nodes, idx = bg.batch_num_nodes, 0

    # Take the representation of the first node of each subgraph.
    s_features, t_features = list(), list()
    for num_nodes in nums_nodes:
        s_features.append(s_out[idx])
        t_features.append(t_out[idx])
        idx += num_nodes
    s_features = torch.stack(s_features)
    t_features = torch.stack(t_features)

    # torch.cat((x, y), -1), x: 2 * 3, y: 2 * 5, result: 2 * 8
    return self.output_layer(torch.cat((s_features, t_features, e_out), -1))
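# A hypothetical usage sketch for the forward pass above. The `model` handle
# and the feature dimensions (F_1 = 16, F_2 = 8, T = 12, F_3 = 5) are
# illustrative assumptions, not values taken from this codebase.
import dgl
import torch


def example_prediction(model) -> torch.Tensor:
    bg = dgl.batch([dgl.rand_graph(4, 8), dgl.rand_graph(6, 12)])  # batch_size = 2
    spatial = torch.randn(bg.num_nodes(), 16)       # [node_num, F_1]
    temporal = torch.randn(bg.num_nodes(), 8, 12)   # [node_num, F_2, T]
    external = torch.randn(bg.batch_size, 5)        # [batch_size, F_3]
    return model(bg, spatial, temporal, external)   # shape [batch_size]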
def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor):
    score = self.score_layer(graph, feature).squeeze()
    perm, next_batch_num_nodes = topk(
        score, self.ratio,
        get_batch_id(graph.batch_num_nodes()),
        graph.batch_num_nodes())
    feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
    graph = dgl.node_subgraph(graph, perm)

    # node_subgraph currently does not support batched graphs:
    # the 'batch_num_nodes' of the resulting subgraph is None,
    # so we set 'batch_num_nodes' manually here.
    # Since global pooling has nothing to do with 'batch_num_edges',
    # we can leave it None or unchanged.
    graph.set_batch_num_nodes(next_batch_num_nodes)

    return graph, feature, perm
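# `get_batch_id` is assumed here rather than shown. A minimal sketch of the
# semantics its usage implies: map every node to the index of the graph it
# belongs to within the batch.
import torch


def get_batch_id(batch_num_nodes: torch.Tensor) -> torch.Tensor:
    # e.g. batch_num_nodes = [3, 2]  ->  [0, 0, 0, 1, 1]
    graph_ids = torch.arange(batch_num_nodes.numel(),
                             device=batch_num_nodes.device)
    return graph_ids.repeat_interleave(batch_num_nodes)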
def forward(  # type: ignore
    self,
    batched_trees: dgl.DGLGraph,
    output_length: int,
    target_sequence: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    batched_trees.ndata["x"] = self._embedding(batched_trees)
    encoded_nodes = self._encoder(batched_trees)
    output_logits = self._decoder(encoded_nodes, batched_trees.batch_num_nodes(),
                                  output_length, target_sequence)
    return output_logits
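# A hypothetical sketch of how this tree-to-sequence forward might be driven
# in training versus evaluation; `model` and the target layout are assumptions.
import dgl
import torch


def example_decode(model, trees: dgl.DGLGraph, targets: torch.Tensor):
    out_len = targets.shape[0]  # assuming targets are [seq_len, batch]
    # Training: teacher forcing with the gold sequence.
    train_logits = model(trees, out_len, target_sequence=targets)
    # Evaluation: free-running decoding, no target provided.
    eval_logits = model(trees, out_len)
    return train_logits, eval_logits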
def forward(self,
            graph: DGLGraph,
            feat: Tensor,
            select_idx: Tensor,
            non_select_idx: Optional[Tensor] = None,
            scores: Optional[Tensor] = None,
            pool_graph=False):
    """
    Description
    -----------
    Perform graph pooling.

    Parameters
    ----------
    graph : dgl.DGLGraph
        The input graph.
    feat : torch.Tensor
        The input node features.
    select_idx : torch.Tensor
        The indices, in the fine graph, of the nodes that form the coarse
        graph; obtained from previous graph pooling layers.
    non_select_idx : torch.Tensor, optional
        The indices not included in the output graph.
        default: :obj:`None`
    scores : torch.Tensor, optional
        Scores for nodes, used for pooling and scaling.
        default: :obj:`None`
    pool_graph : bool, optional
        Whether to perform graph pooling on the graph topology.
        default: :obj:`False`
    """
    if self.use_gcn:
        feat = self.down_sample_gcn(graph, feat)

    feat = feat[select_idx]
    if scores is not None:
        feat = feat * scores.unsqueeze(-1)

    if pool_graph:
        num_node_batch = graph.batch_num_nodes()
        graph = dgl.node_subgraph(graph, select_idx)
        graph.set_batch_num_nodes(num_node_batch)
        return feat, graph
    else:
        return feat
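# A hypothetical usage sketch for the pooling layer above. Selecting
# `select_idx` via top-k scores is an illustrative stand-in; in the full
# model it comes from earlier pooling layers.
import torch


def example_pool(pool_layer, graph, feat, scores, k: int):
    select_idx = torch.topk(scores, k).indices
    feat, coarse_graph = pool_layer(graph, feat, select_idx,
                                    scores=scores, pool_graph=True)
    return coarse_graph, feat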
def forward(self, graph: DGLGraph, feat: Tensor, e_feat=None):
    # top-k pool first
    if e_feat is None:
        e_feat = torch.ones((graph.number_of_edges(),),
                            dtype=feat.dtype, device=feat.device)
    batch_num_nodes = graph.batch_num_nodes()
    x_score = self.calc_info_score(graph, feat, e_feat)
    perm, next_batch_num_nodes = topk(x_score, self.ratio,
                                      get_batch_id(batch_num_nodes),
                                      batch_num_nodes)
    feat = feat[perm]
    pool_graph = None
    if not self.sample or not self.sl:
        # pool graph
        graph.edata["e"] = e_feat
        pool_graph = dgl.node_subgraph(graph, perm)
        e_feat = pool_graph.edata.pop("e")
        pool_graph.set_batch_num_nodes(next_batch_num_nodes)

    # no structure learning layer, directly return.
    if not self.sl:
        return pool_graph, feat, e_feat, perm

    # Structure Learning
    if self.sample:
        # A fast mode for large graphs. Learning the possible edge weights
        # between every pair of nodes is time-consuming in large graphs, so
        # we sample the k-hop neighbors of each node and only learn the
        # edge weights between them.

        # first build multi-hop graph
        row, col = graph.all_edges()
        num_nodes = graph.num_nodes()
        scipy_adj = scipy.sparse.coo_matrix(
            (e_feat.detach().cpu(),
             (row.detach().cpu(), col.detach().cpu())),
            shape=(num_nodes, num_nodes))
        for _ in range(self.k_hop):
            two_hop = scipy_adj**2
            two_hop = two_hop * (1e-5 / two_hop.max())
            scipy_adj = two_hop + scipy_adj
        row, col = scipy_adj.nonzero()
        row = torch.tensor(row, dtype=torch.long, device=graph.device)
        col = torch.tensor(col, dtype=torch.long, device=graph.device)
        e_feat = torch.tensor(scipy_adj.data, dtype=torch.float,
                              device=feat.device)

        # perform pooling on multi-hop graph
        mask = perm.new_full((num_nodes,), -1)
        i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
        mask[perm] = i
        row, col = mask[row], mask[col]
        mask = (row >= 0) & (col >= 0)
        row, col = row[mask], col[mask]
        e_feat = e_feat[mask]

        # add remaining self loops
        mask = row != col
        num_nodes = perm.size(0)  # num nodes after pool
        loop_index = torch.arange(0, num_nodes, dtype=row.dtype,
                                  device=row.device)
        inv_mask = ~mask
        loop_weight = torch.full((num_nodes,), 0, dtype=e_feat.dtype,
                                 device=e_feat.device)
        remaining_e_feat = e_feat[inv_mask]
        if remaining_e_feat.numel() > 0:
            loop_weight[row[inv_mask]] = remaining_e_feat
        e_feat = torch.cat([e_feat[mask], loop_weight], dim=0)
        row, col = row[mask], col[mask]
        row = torch.cat([row, loop_index], dim=0)
        col = torch.cat([col, loop_index], dim=0)

        # attention scores
        weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
        weights = F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb

        # structure learning normalization
        sl_graph = dgl.graph((row, col))
        if self.sparse:
            weights = edge_sparsemax(sl_graph, weights)
        else:
            weights = edge_softmax(sl_graph, weights)

        # get final graph
        mask = torch.abs(weights) > 0
        row, col, weights = row[mask], col[mask], weights[mask]
        pool_graph = dgl.graph((row, col))
        pool_graph.set_batch_num_nodes(next_batch_num_nodes)
        e_feat = weights
    else:
        # Learn the possible edge weights between every pair of nodes in the
        # pooled subgraph; relatively slower. Construct a complete graph for
        # each graph in the batch: build it dense first, then convert to
        # sparse. Maybe there's a more efficient way? (A sketch follows
        # after this function.)
        batch_num_nodes = next_batch_num_nodes
        block_begin_idx = torch.cat([
            batch_num_nodes.new_zeros(1),
            batch_num_nodes.cumsum(dim=0)[:-1]
        ], dim=0)
        block_end_idx = batch_num_nodes.cumsum(dim=0)
        dense_adj = torch.zeros(
            (pool_graph.num_nodes(), pool_graph.num_nodes()),
            dtype=torch.float, device=feat.device)
        for idx_b, idx_e in zip(block_begin_idx, block_end_idx):
            dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.
        row, col = torch.nonzero(dense_adj).t().contiguous()

        # compute weights for node pairs
        weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
        weights = F.leaky_relu(weights, self.negative_slop)
        dense_adj[row, col] = weights

        # add pooled graph structure to weight matrix
        pool_row, pool_col = pool_graph.all_edges()
        dense_adj[pool_row, pool_col] += self.lamb * e_feat
        weights = dense_adj[row, col]
        del dense_adj
        torch.cuda.empty_cache()

        # edge softmax/sparsemax
        complete_graph = dgl.graph((row, col))
        if self.sparse:
            weights = edge_sparsemax(complete_graph, weights)
        else:
            weights = edge_softmax(complete_graph, weights)

        # get new e_feat and graph structure, clean up
        mask = torch.abs(weights) > 1e-9
        row, col, weights = row[mask], col[mask], weights[mask]
        e_feat = weights
        pool_graph = dgl.graph((row, col))
        pool_graph.set_batch_num_nodes(next_batch_num_nodes)

    return pool_graph, feat, e_feat, perm
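# Responding to the "more efficient way?" comment above: a sketch that builds
# the block-diagonal complete-graph edges per batched graph directly, without
# materializing the dense [N, N] adjacency. `complete_graph_edges` is a
# hypothetical helper, not part of this codebase.
import torch


def complete_graph_edges(batch_num_nodes: torch.Tensor):
    rows, cols = [], []
    offset = 0
    for n in batch_num_nodes.tolist():
        idx = torch.arange(offset, offset + n, device=batch_num_nodes.device)
        rows.append(idx.repeat_interleave(n))  # pair every node ...
        cols.append(idx.repeat(n))             # ... with every node in its graph
        offset += n
    return torch.cat(rows), torch.cat(cols)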