def forward(self, z_n, z_g, graph: dgl.DGLGraph) -> Tensor:
        '''
        Args:
            z_n: Tensor of shape [n_nodes, z_dim].
            z_g: Tensor of shape [n_graphs, z_dim].
            graph: Batched DGLGraph with n_graphs graphs and n_nodes nodes in total.
        '''
        device = z_g.device
        num_graphs = z_g.shape[0]
        num_nodes = z_n.shape[0]
        # Cumulative node counts give the end offset of each graph's nodes.
        node_indices = torch.cumsum(graph.batch_num_nodes(), dim=0)

        # pos_mask[i, j] = 1 iff node i belongs to graph j.
        pos_mask = torch.zeros((num_nodes, num_graphs), device=device)
        start = 0
        for graph_idx, end in enumerate(node_indices.tolist()):
            pos_mask[start:end, graph_idx] = 1.
            start = end
        neg_mask = 1 - pos_mask

        d_prime = torch.matmul(z_n, z_g.t())

        E_pos = get_expectation(d_prime * pos_mask, positive=True).sum()
        E_pos = E_pos / num_nodes
        E_neg = get_expectation(d_prime * neg_mask, positive=False).sum()
        E_neg = E_neg / (num_nodes * (num_graphs - 1))
        return E_neg - E_pos
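The loss above relies on a `get_expectation` helper defined elsewhere in the source. A minimal sketch, assuming the softplus-based Jensen-Shannon estimator commonly used for this kind of local-global mutual-information objective (the original helper may differ):

import math

import torch.nn.functional as F
from torch import Tensor


def get_expectation(scores: Tensor, positive: bool = True) -> Tensor:
    # Hypothetical helper: Jensen-Shannon MI estimator terms.
    # Positive pairs contribute log(2) - softplus(-x); negative pairs
    # contribute softplus(-x) + x - log(2).
    if positive:
        return math.log(2.0) - F.softplus(-scores)
    return F.softplus(-scores) + scores - math.log(2.0)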
Example #2
    def forward(self, bg: dgl.DGLGraph, spatial_features: torch.Tensor,
                temporal_features: torch.Tensor,
                external_features: torch.Tensor):
        """
        get predictions
        :param bg: batched graphs,
             with the total number of nodes is `node_num`,
             including `batch_size` disconnected subgraphs
        :param spatial_features: shape [node_num, F_1]
        :param temporal_features: shape [node_num, F_2, T]
        :param external_features: shape [batch_size, F_3]
        :return: a tensor, shape [batch_size], with the prediction results for each graphs
        """

        if get_attribute("use_SBlock"):
            # s_out of shape [node_num, 10]
            s_out = self.spatial_gcn(bg,
                                     self.spatial_embedding(spatial_features))

        else:
            # remove spatial layer
            s_out = self.replace_spatial_gcn(spatial_features)

        if get_attribute("use_STBlock"):
            # temporal_embeddings of shape [node_num, 20, T_in]
            temporal_embeddings = self.temporal_embedding(
                bg, temporal_features)
        else:
            # remove temporal layer
            temporal_embeddings = torch.transpose(
                self.replace_temporal_layer(
                    torch.transpose(temporal_features, -1, -2)), -1, -2)

        # t_out of shape [1, node_num, 10]
        # _, (t_out, _) = self.temporal_agg(torch.transpose(temporal_embeddings, -1, -2))
        t_out = self.temporal_agg(temporal_embeddings)
        t_out.squeeze_()

        if get_attribute("use_Embedding"):
            e_out = self.external_embedding(external_features)
        else:
            # remove external embedding layer
            e_out = external_features

        try:
            # DGL >= 0.5 exposes batch_num_nodes as a method.
            nums_nodes, idx = bg.batch_num_nodes(), 0
        except TypeError:
            # Older DGL versions expose it as an attribute.
            nums_nodes, idx = bg.batch_num_nodes, 0
        s_features, t_features = list(), list()
        for num_nodes in nums_nodes:
            # Take each subgraph's first node as its graph-level readout.
            s_features.append(s_out[idx])
            t_features.append(t_out[idx])
            idx += num_nodes

        s_features = torch.stack(s_features)
        t_features = torch.stack(t_features)

        # torch.cat((x, y), -1): x of shape [2, 3], y of shape [2, 5] -> result of shape [2, 8]
        return self.output_layer(torch.cat((s_features, t_features, e_out),
                                           -1))
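The readout loop above picks the feature of each subgraph's first node. An equivalent vectorized gather, sketched under the assumption that `bg.batch_num_nodes()` is available as a method (DGL >= 0.5):

# Start offset of each subgraph's first node within the batched tensors.
num_nodes_per_graph = bg.batch_num_nodes()
offsets = torch.cat([num_nodes_per_graph.new_zeros(1),
                     num_nodes_per_graph.cumsum(dim=0)[:-1]])
s_features = s_out[offsets]  # shape [batch_size, 10]
t_features = t_out[offsets]  # shape [batch_size, 10]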
Example #3
    def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor):
        score = self.score_layer(graph, feature).squeeze()
        perm, next_batch_num_nodes = topk(
            score, self.ratio, get_batch_id(graph.batch_num_nodes()),
            graph.batch_num_nodes())
        feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
        graph = dgl.node_subgraph(graph, perm)

        # node_subgraph currently does not support batched graphs:
        # the 'batch_num_nodes' of the resulting subgraph is None,
        # so we set it manually here. Since global pooling does not
        # depend on 'batch_num_edges', we can leave it None/unchanged.
        graph.set_batch_num_nodes(next_batch_num_nodes)

        return graph, feature, perm
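This example (and Example #6 below) assumes `topk` and `get_batch_id` helpers from the surrounding module. A minimal sketch of `get_batch_id`, which maps every node of a batched graph to the index of the graph it belongs to (the original helper may be implemented differently):

import torch


def get_batch_id(batch_num_nodes: torch.Tensor) -> torch.Tensor:
    # e.g. batch_num_nodes = [3, 2] -> [0, 0, 0, 1, 1]
    graph_ids = torch.arange(batch_num_nodes.size(0),
                             device=batch_num_nodes.device)
    return graph_ids.repeat_interleave(batch_num_nodes)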
Example #4
 def forward(  # type: ignore
     self,
     batched_trees: dgl.DGLGraph,
     output_length: int,
     target_sequence: torch.Tensor = None,
 ) -> torch.Tensor:
     batched_trees.ndata["x"] = self._embedding(batched_trees)
     encoded_nodes = self._encoder(batched_trees)
     output_logits = self._decoder(encoded_nodes,
                                   batched_trees.batch_num_nodes(),
                                   output_length, target_sequence)
     return output_logits
Example #5
    def forward(self,
                graph: DGLGraph,
                feat: Tensor,
                select_idx: Tensor,
                non_select_idx: Optional[Tensor] = None,
                scores: Optional[Tensor] = None,
                pool_graph=False):
        """
        Description
        -----------
        Perform graph pooling.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph
        feat : torch.Tensor
            The input node feature
        select_idx : torch.Tensor
            The indices, in the fine graph, of the nodes kept in the
            coarse graph; obtained from previous graph pooling layers.
        non_select_idx : torch.Tensor, optional
            The indices of nodes not included in the output graph.
            default: :obj:`None`
        scores : torch.Tensor, optional
            Node scores used for pooling and feature scaling.
            default: :obj:`None`
        pool_graph : bool, optional
            Whether to also perform pooling on the graph topology.
            default: :obj:`False`
        """
        if self.use_gcn:
            feat = self.down_sample_gcn(graph, feat)

        feat = feat[select_idx]
        if scores is not None:
            feat = feat * scores.unsqueeze(-1)

        if pool_graph:
            num_node_batch = graph.batch_num_nodes()
            graph = dgl.node_subgraph(graph, select_idx)
            graph.set_batch_num_nodes(num_node_batch)
            return feat, graph
        else:
            return feat
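A hedged usage sketch; the layer instance `pool_layer` and the node scores `node_scores` are hypothetical names standing in for values produced by earlier layers, as the docstring describes:

# Keep the top half of the nodes by score and rescale their features;
# with pool_graph=False only the pooled features are returned.
k = graph.num_nodes() // 2
scores, select_idx = torch.topk(node_scores, k)
pooled_feat = pool_layer(graph, feat, select_idx,
                         scores=torch.sigmoid(scores),
                         pool_graph=False)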
Example #6
    def forward(self, graph: DGLGraph, feat: Tensor, e_feat=None):
        # top-k pool first
        if e_feat is None:
            e_feat = torch.ones((graph.number_of_edges(), ),
                                dtype=feat.dtype,
                                device=feat.device)
        batch_num_nodes = graph.batch_num_nodes()
        x_score = self.calc_info_score(graph, feat, e_feat)
        perm, next_batch_num_nodes = topk(x_score, self.ratio,
                                          get_batch_id(batch_num_nodes),
                                          batch_num_nodes)
        feat = feat[perm]
        pool_graph = None
        if not self.sample or not self.sl:
            # pool graph
            graph.edata["e"] = e_feat
            pool_graph = dgl.node_subgraph(graph, perm)
            e_feat = pool_graph.edata.pop("e")
            pool_graph.set_batch_num_nodes(next_batch_num_nodes)

        # no structure learning layer, directly return.
        if not self.sl:
            return pool_graph, feat, e_feat, perm

        # Structure Learning
        if self.sample:
            # A fast mode for large graphs.
            # In large graphs, learning the possible edge weights between each
            # pair of nodes is time-consuming. To accelerate this process,
            # we sample the k-hop neighbors of each node and then learn the
            # edge weights between them.

            # first build multi-hop graph
            row, col = graph.all_edges()
            num_nodes = graph.num_nodes()

            scipy_adj = scipy.sparse.coo_matrix(
                (e_feat.detach().cpu(),
                 (row.detach().cpu(), col.detach().cpu())),
                shape=(num_nodes, num_nodes))
            for _ in range(self.k_hop):
                two_hop = scipy_adj**2
                two_hop = two_hop * (1e-5 / two_hop.max())
                scipy_adj = two_hop + scipy_adj
            row, col = scipy_adj.nonzero()
            row = torch.tensor(row, dtype=torch.long, device=graph.device)
            col = torch.tensor(col, dtype=torch.long, device=graph.device)
            e_feat = torch.tensor(scipy_adj.data,
                                  dtype=torch.float,
                                  device=feat.device)

            # perform pooling on multi-hop graph
            mask = perm.new_full((num_nodes, ), -1)
            i = torch.arange(perm.size(0),
                             dtype=torch.long,
                             device=perm.device)
            mask[perm] = i
            row, col = mask[row], mask[col]
            mask = (row >= 0) & (col >= 0)
            row, col = row[mask], col[mask]
            e_feat = e_feat[mask]

            # add remaining self loops
            mask = row != col
            num_nodes = perm.size(0)  # num nodes after pool
            loop_index = torch.arange(0,
                                      num_nodes,
                                      dtype=row.dtype,
                                      device=row.device)
            inv_mask = ~mask
            loop_weight = torch.full((num_nodes, ),
                                     0,
                                     dtype=e_feat.dtype,
                                     device=e_feat.device)
            remaining_e_feat = e_feat[inv_mask]
            if remaining_e_feat.numel() > 0:
                loop_weight[row[inv_mask]] = remaining_e_feat
            e_feat = torch.cat([e_feat[mask], loop_weight], dim=0)
            row, col = row[mask], col[mask]
            row = torch.cat([row, loop_index], dim=0)
            col = torch.cat([col, loop_index], dim=0)

            # attention scores
            weights = (torch.cat([feat[row], feat[col]], dim=1) *
                       self.att).sum(dim=-1)
            weights = F.leaky_relu(weights,
                                   self.negative_slop) + e_feat * self.lamb

            # structure learning and normalization
            sl_graph = dgl.graph((row, col))
            if self.sparse:
                weights = edge_sparsemax(sl_graph, weights)
            else:
                weights = edge_softmax(sl_graph, weights)

            # get final graph
            mask = torch.abs(weights) > 0
            row, col, weights = row[mask], col[mask], weights[mask]
            pool_graph = dgl.graph((row, col))
            pool_graph.set_batch_num_nodes(next_batch_num_nodes)
            e_feat = weights

        else:
            # Learn the possible edge weights between each pair of
            # nodes in the pooled subgraph; relatively slower.

            # Construct a complete graph for every graph in the batch:
            # build it densely, then convert to a sparse form.
            # (Maybe there is a more efficient way? See the sketch after
            # this example.)
            batch_num_nodes = next_batch_num_nodes
            block_begin_idx = torch.cat([
                batch_num_nodes.new_zeros(1),
                batch_num_nodes.cumsum(dim=0)[:-1]
            ],
                                        dim=0)
            block_end_idx = batch_num_nodes.cumsum(dim=0)
            dense_adj = torch.zeros(
                (pool_graph.num_nodes(), pool_graph.num_nodes()),
                dtype=torch.float,
                device=feat.device)
            for idx_b, idx_e in zip(block_begin_idx, block_end_idx):
                dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.
            row, col = torch.nonzero(dense_adj).t().contiguous()

            # compute weights for node-pairs
            weights = (torch.cat([feat[row], feat[col]], dim=1) *
                       self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop)
            dense_adj[row, col] = weights

            # add pooled graph structure to weight matrix
            pool_row, pool_col = pool_graph.all_edges()
            dense_adj[pool_row, pool_col] += self.lamb * e_feat
            weights = dense_adj[row, col]
            del dense_adj
            torch.cuda.empty_cache()

            # edge softmax/sparsemax
            complete_graph = dgl.graph((row, col))
            if self.sparse:
                weights = edge_sparsemax(complete_graph, weights)
            else:
                weights = edge_softmax(complete_graph, weights)

            # get new e_feat and graph structure, clean up.
            mask = torch.abs(weights) > 1e-9
            row, col, weights = row[mask], col[mask], weights[mask]
            e_feat = weights
            pool_graph = dgl.graph((row, col))
            pool_graph.set_batch_num_nodes(next_batch_num_nodes)

        return pool_graph, feat, e_feat, perm
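Regarding the "more efficient way?" question in the non-sampling branch: the block-diagonal complete graph can also be built without allocating a dense adjacency matrix, e.g. with per-graph index ranges. A sketch, assuming `batch_num_nodes` and `feat` as in the example (only the edge-index construction is shown; the weight computation stays the same; `indexing="ij"` needs a recent PyTorch):

rows, cols = [], []
offset = 0
for n in batch_num_nodes.tolist():
    # All node pairs inside this subgraph's block.
    idx = torch.arange(offset, offset + n, device=feat.device)
    r, c = torch.meshgrid(idx, idx, indexing="ij")
    rows.append(r.reshape(-1))
    cols.append(c.reshape(-1))
    offset += n
row = torch.cat(rows)
col = torch.cat(cols)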