Example #1
    def forward(self, input_hidden, graphs: dgl.DGLGraph, batch_num_nodes=None):
        if batch_num_nodes is None:
            b_num_nodes = graphs.batch_num_nodes
        else:
            b_num_nodes = batch_num_nodes
        h_t = self.input_proj(input_hidden)
        # when there are no edges in the graph, there is nothing to do
        if graphs.number_of_edges() > 0:
            # give all the nodes and edges information about the current query hidden state
            broadcasted_hn = dgl.broadcast_nodes(graphs, h_t)
            graphs.ndata['h_t'] = broadcasted_hn
            broadcasted_he = dgl.broadcast_edges(graphs, h_t)
            graphs.edata['h_t'] = broadcasted_he
            # create a copy of the node and edge states which will be updated for K iterations
            graphs.ndata['F_n_t'] = graphs.ndata['F_n']
            graphs.edata['F_e_t'] = graphs.edata['F_e']

            for _ in range(self.k_update_steps):
                graphs.ndata['s_n'] = self.object_score(torch.cat([graphs.ndata['h_t'], graphs.ndata['F_n_t']], dim=-1))
                graphs.send(message_func=self.io_attention_send)
                graphs.recv(reduce_func=self.io_attention_reduce)
                graphs.ndata['F_n_t'] = graphs.ndata['F_i_tplus1']
                if self.update_relations:
                    graphs.edata['F_e_t'] = graphs.edata['F_e_tplus1']

            io = torch.split(graphs.ndata['F_n_t'], split_size_or_sections=b_num_nodes)
        else:
            io = torch.split(graphs.ndata['F_n'], split_size_or_sections=b_num_nodes)
        io = pad_sequence(io, batch_first=True)
        io_mask = io.sum(dim=-1) != 0

        return io, io_mask
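A minimal standalone sketch (names and feature sizes are illustrative, not taken from the module above) of the split/pad/mask step at the end of this forward: per-graph node features are split by batch_num_nodes, padded to a common length, and masked.

import dgl
import torch
from torch.nn.utils.rnn import pad_sequence

g1 = dgl.graph(([0, 1], [1, 2]))                         # 3 nodes
g2 = dgl.graph(([0], [1]))                               # 2 nodes
bg = dgl.batch([g1, g2])

node_feats = torch.randn(bg.num_nodes(), 4)              # (5, 4)
sizes = bg.batch_num_nodes().tolist()                    # [3, 2]
per_graph = torch.split(node_feats, split_size_or_sections=sizes)
padded = pad_sequence(per_graph, batch_first=True)       # (2, 3, 4)
# padding rows are all-zero; this mask assumes no real node feature sums exactly to zero
mask = padded.sum(dim=-1) != 0                           # (2, 3)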
Example #2
    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """
        Parameters
        ----------
        g : DGLGraph or BatchedDGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        g_feats : float32 tensor of shape (G, N2)
            Input graph features. G for the number of graphs and N2 for the feature size.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, N2)
            Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout.
        """
        with g.local_scope():
            g.ndata['z'] = self.compute_logits(
                torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)
            context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))

            if get_node_weight:
                return self.gru(context, g_feats), g.ndata['a']
            else:
                return self.gru(context, g_feats)
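A self-contained sketch of the readout pattern above (the linear layer and feature sizes are made up for illustration): broadcast the graph state to its nodes, score each node, normalize the scores per graph with softmax_nodes, and take the weighted sum with sum_nodes.

import dgl
import torch
import torch.nn.functional as F

g = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
node_feats = torch.randn(g.num_nodes(), 8)     # (V, 8)
g_feats = torch.randn(g.batch_size, 8)         # (G, 8)
score_fc = torch.nn.Linear(16, 1)              # hypothetical scoring layer

with g.local_scope():
    g.ndata['z'] = score_fc(torch.cat(
        [dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
    g.ndata['a'] = dgl.softmax_nodes(g, 'z')   # per-graph softmax over nodes
    g.ndata['hv'] = node_feats
    context = dgl.sum_nodes(g, 'hv', 'a')      # (G, 8) attention-weighted readout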
Example #3
    def forward(self, graph, node_feat, edge_feat):
        if self.virtual_node:
            virtual_emb = self.virtual_emb.weight.expand(graph.batch_size, -1)

        hn = self.node_encoder(node_feat)

        for layer in range(self.num_layers):

            if self.virtual_node:
                # messages from virtual nodes to graph nodes
                virtual_hn = dgl.broadcast_nodes(graph, virtual_emb)
                hn = hn + virtual_hn

            he = self.edge_encoders[layer](edge_feat)
            hn = self.conv_layers[layer](graph, hn, he)
            if layer != self.num_layers - 1:
                hn = F.relu(hn)
            hn = self.dropout(hn)

            if self.virtual_node and layer != self.num_layers - 1:
                # messages from graph nodes to virtual nodes
                virtual_emb_tmp = self.virtual_pool(graph, hn) + virtual_emb
                virtual_emb = self.mlp_virtual[layer](virtual_emb_tmp)
                virtual_emb = self.dropout(F.relu(virtual_emb))

        hg = self.pool(graph, hn)

        return self.pred(hg)
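The two broadcast directions used above, as a minimal sketch (tensor sizes and the sum pooling are illustrative; the module's own pool may differ): graph-level embeddings are broadcast to every node of their graph, and node states are pooled back per graph.

import dgl
import torch

g = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
hn = torch.randn(g.num_nodes(), 16)            # node states
virtual_emb = torch.zeros(g.batch_size, 16)    # one embedding per graph

# virtual node -> graph nodes
hn = hn + dgl.broadcast_nodes(g, virtual_emb)

# graph nodes -> virtual node (sum pooling here)
with g.local_scope():
    g.ndata['h'] = hn
    virtual_emb = dgl.sum_nodes(g, 'h') + virtual_emb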
Example #4
    def forward(self, inputs, extra_inputs=None):
        KG_embeddings = super().forward(extra_inputs)

        uid, g = inputs
        iid = g.ndata['iid']  # (num_nodes,)
        feat_i = KG_embeddings['item'][iid]
        feat_u = KG_embeddings['user'][uid]
        feat = self.fc_i(feat_i) + dgl.broadcast_nodes(g, self.fc_u(feat_u))
        feat_i = self.PSE_layer(g, feat)
        sr = th.cat([feat_i, feat_u], dim=1)
        logits = self.fc_sr(sr) @ self.item_embedding(self.item_indices).t()
        return logits
Example #5
File: lessr.py Project: lessr/lessr
 def forward(self, g, feat, last_nodes):
     with g.local_scope():
         if self.batch_norm is not None:
             feat = self.batch_norm(feat)
         feat_u = self.fc_u(feat)
         feat_v = self.fc_v(feat[last_nodes])
         feat_v = dgl.broadcast_nodes(g, feat_v)
         g.ndata['e'] = self.attn_e(th.sigmoid(feat_u + feat_v))
         alpha = dgl.softmax_nodes(g, 'e')
         g.ndata['w'] = feat * alpha
         rst = dgl.sum_nodes(g, 'w')
         rst = self.fc_out(rst)
         return rst
Example #6
    def forward(self, g, x, edge_attr):
        ### virtual node embeddings for graphs
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(g.batch_size).to(x.dtype).to(x.device))

        h_list = [self.atom_encoder(x)]
        batch_id = dgl.broadcast_nodes(g, torch.arange(g.batch_size).to(x.device))
        for layer in range(self.num_layers):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id]

            ### Message passing among graph nodes
            h = self.convs[layer](g, h_list[layer], edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layers - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = self.pool(g, h_list[layer]) + virtualnode_embedding
                ### transform virtual nodes using MLP
                virtualnode_embedding_temp = self.mlp_virtualnode_list[layer](
                    virtualnode_embedding_temp)

                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(
                        virtualnode_embedding_temp, self.drop_ratio, training = self.training)
                else:
                    virtualnode_embedding = F.dropout(
                        virtualnode_embedding_temp, self.drop_ratio, training = self.training)

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers):
                node_representation += h_list[layer]

        return node_representation
Example #7
def collate(samples):
    ''' collate function for building graph dataloader '''

    # generate batched graphs and labels
    graphs, targets = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    batched_targets = th.Tensor(targets)
    
    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)
    
    batched_graph.ndata['graph_id'] = graph_id
    
    return batched_graph, batched_targets
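Hypothetical usage of the collate function above with a PyTorch DataLoader; `samples` is a stand-in dataset of (graph, target) pairs.

import dgl
import torch as th
from torch.utils.data import DataLoader

samples = [(dgl.graph(([0, 1], [1, 2])), 0.0),
           (dgl.graph(([0], [1])), 1.0)]
loader = DataLoader(samples, batch_size=2, collate_fn=collate)
for batched_graph, batched_targets in loader:
    # every node carries the id of the graph it came from
    print(batched_graph.ndata['graph_id'], batched_targets)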
Example #8
def test_broadcast_nodes():
    # test#1: basic
    g0 = dgl.DGLGraph(nx.path_graph(10))
    feat0 = F.randn((40, ))
    ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
    assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(3))
    g2 = dgl.DGLGraph()
    g3 = dgl.DGLGraph(nx.path_graph(12))
    bg = dgl.batch([g0, g1, g2, g3])
    feat1 = F.randn((40, ))
    feat2 = F.randn((40, ))
    feat3 = F.randn((40, ))
    ground_truth = F.stack(
        [feat0] * g0.number_of_nodes() +\
        [feat1] * g1.number_of_nodes() +\
        [feat2] * g2.number_of_nodes() +\
        [feat3] * g3.number_of_nodes(), 0
    )
    assert F.allclose(
        dgl.broadcast_nodes(bg, F.stack([feat0, feat1, feat2, feat3], 0)),
        ground_truth)
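For a batched graph, broadcast_nodes amounts to repeating each graph's feature row once per node of that graph; a small sketch of that equivalence (illustrative sizes only):

import dgl
import torch

bg = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
gfeat = torch.randn(bg.batch_size, 40)

broadcast = dgl.broadcast_nodes(bg, gfeat)                            # (N, 40)
manual = torch.repeat_interleave(gfeat, bg.batch_num_nodes(), dim=0)  # (N, 40)
assert torch.allclose(broadcast, manual)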
Example #9
 def forward(self, g, feat, last_nodes):
     if self.batch_norm is not None:
         feat = self.batch_norm(feat)
     feat = self.feat_drop(feat)
     feat_u = self.fc_u(feat)
     feat_v = self.fc_v(feat[last_nodes])
     feat_v = dgl.broadcast_nodes(g, feat_v)
     e = self.fc_e(th.sigmoid(feat_u + feat_v))
     alpha = F.segment.segment_softmax(g.batch_num_nodes(), e)
     feat_norm = feat * alpha
     rst = F.segment.segment_reduce(g.batch_num_nodes(), feat_norm, 'sum')
     if self.fc_out is not None:
         rst = self.fc_out(rst)
     if self.activation is not None:
         rst = self.activation(rst)
     return rst
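A standalone sketch (illustrative shapes) of the segment-based readout used above: segment_softmax normalizes per-node scores within each graph of the batch, which matches dgl.softmax_nodes on the batched graph.

import dgl
import torch
from dgl.ops import segment_softmax, segment_reduce

g = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
feat = torch.randn(g.num_nodes(), 8)
e = torch.randn(g.num_nodes(), 1)                               # per-node scores

alpha = segment_softmax(g.batch_num_nodes(), e)                 # per-graph softmax
rst = segment_reduce(g.batch_num_nodes(), feat * alpha, 'sum')  # (batch_size, 8)

with g.local_scope():
    g.ndata['e'] = e
    assert torch.allclose(alpha, dgl.softmax_nodes(g, 'e'))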
Example #10
def collate(samples):
    ''' collate function for building the graph dataloader'''
    graphs, diff_graphs, labels = map(list, zip(*samples))

    # generate batched graphs and labels
    batched_graph = dgl.batch(graphs)
    batched_labels = th.tensor(labels)
    batched_diff_graph = dgl.batch(diff_graphs)

    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)

    batched_graph.ndata['graph_id'] = graph_id

    return batched_graph, batched_diff_graph, batched_labels
Example #11
 def forward(self, g, feat, last_nodes):
     if self.batch_norm is not None:
         feat = self.batch_norm(feat)
     if self.feat_drop is not None:
         feat = self.feat_drop(feat)
     feat_u = self.fc_u(feat)
     feat_v = self.fc_v(feat[last_nodes])
     feat_v = dgl.broadcast_nodes(g, feat_v)
     e = self.fc_e(th.sigmoid(feat_u + feat_v))  # (num_nodes, 1)
     alpha = e * g.ndata['cnt'].view_as(e)
     rst = F.segment.segment_reduce(g.batch_num_nodes(), feat * alpha, 'sum')
     if self.fc_out is not None:
         rst = self.fc_out(rst)
     if self.activation is not None:
         rst = self.activation(rst)
     return rst
Example #12
 def forward(self, graph: dgl.DGLGraph, feat, lambda_max=None):
     shp = (len(graph.nodes()), ) + tuple(1 for _ in range(feat.dim() - 1))
     with graph.local_scope():
         norm = torch.pow(graph.in_degrees().float().clamp(min=1),
                          -0.5).reshape(shp).to(feat.device)
         if lambda_max is None:
             try:
                 lambda_max = laplacian_lambda_max(graph)
             except ArpackNoConvergence:
                 lambda_max = [2.] * graph.batch_size
         if isinstance(lambda_max, list):
             lambda_max = torch.tensor(lambda_max).to(feat.device)
         if lambda_max.dim() < 1:
             lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)
         # broadcast from (B, 1) to (N, 1)
         lambda_max = torch.reshape(broadcast_nodes(graph, lambda_max),
                                    shp).float()
         # T0(X)
         Tx_0 = feat
         rst = self.fc[0](Tx_0)
         # T1(X)
         if self._k > 1:
             graph.ndata['h'] = Tx_0 * norm
             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
             h = graph.ndata.pop('h') * norm
             # Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I
             #   = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I
             Tx_1 = -2. * h / lambda_max + Tx_0 * (2. / lambda_max - 1)
             rst = rst + self.fc[1](Tx_1)
         # Ti(x), i = 2...k
         for i in range(2, self._k):
             graph.ndata['h'] = Tx_1 * norm
             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
             h = graph.ndata.pop('h') * norm
             # Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2)
             #      = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) +
             #        (4 / lambda_max - 2) Tx_(k-1) -
             #        Tx_(k-2)
             Tx_2 = -4. * h / lambda_max + Tx_1 * (4. / lambda_max -
                                                   2) - Tx_0
             rst = rst + self.fc[i](Tx_2)
             Tx_1, Tx_0 = Tx_2, Tx_1
         # add bias
         if self.bias is not None:
             rst = rst + self.bias
         return rst
Example #13
def test_broadcast(idtype, g):
    g = g.astype(idtype).to(F.ctx())
    gfeat = F.randn((g.batch_size, 3))

    # Test.0: broadcast_nodes
    g.ndata['h'] = dgl.broadcast_nodes(g, gfeat)
    subg = dgl.unbatch(g)
    for i, sg in enumerate(subg):
        assert F.allclose(
            sg.ndata['h'],
            F.repeat(F.reshape(gfeat[i], (1, 3)), sg.number_of_nodes(), dim=0))

    # Test.1: broadcast_edges
    g.edata['h'] = dgl.broadcast_edges(g, gfeat)
    subg = dgl.unbatch(g)
    for i, sg in enumerate(subg):
        assert F.allclose(
            sg.edata['h'],
            F.repeat(F.reshape(gfeat[i], (1, 3)), sg.number_of_edges(), dim=0))
Example #14
 def forward(self, g, feat_i, feat_u, last_nodes):
     if self.batch_norm is not None:
         feat_i = self.batch_norm['item'](feat_i)
         feat_u = self.batch_norm['user'](feat_u)
     if self.feat_drop is not None:
         feat_i = self.feat_drop(feat_i)
         feat_u = self.feat_drop(feat_u)
     feat_val = feat_i
     feat_key = self.fc_key(feat_i)
     feat_u = self.fc_user(feat_u)
     feat_last = self.fc_last(feat_i[last_nodes])
     feat_qry = dgl.broadcast_nodes(g, feat_u + feat_last)
     e = self.fc_e(th.sigmoid(feat_qry + feat_key))  # (num_nodes, 1)
     e = e + g.ndata['cnt'].log().view_as(e)
     alpha = F.segment.segment_softmax(g.batch_num_nodes(), e)
     rst = F.segment.segment_reduce(g.batch_num_nodes(), alpha * feat_val,
                                    'sum')
     if self.activation is not None:
         rst = self.activation(rst)
     return rst
Example #15
    def forward(self, graphs, nodes_feat, edges_feat, nodes_num_norm_sqrt,
                edges_num_norm_sqrt):
        h = self.embedding_h(nodes_feat)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            h = conv(graphs, h, nodes_num_norm_sqrt)
            pass

        graphs.ndata['h'] = h
        h_mean = dgl.mean_nodes(graphs, 'h')
        h_mean = dgl.broadcast_nodes(graphs, h_mean)
        h_mean_and_h = torch.cat([h, h_mean], dim=-1)

        nodes_attention = 2 * torch.sigmoid(
            self.attention2(torch.relu(self.attention1(h_mean_and_h)))) - 1
        h = h + h * nodes_attention
        graphs.ndata['h'] = h
        hg = dgl.mean_nodes(graphs, 'h')

        logits = self.readout_mlp(hg)
        return logits
Example #16
    def forward(self, graph, global_attr=None, out_node_key='h_v'):
        def recv_func(nodes):
            nodes_to_collect = []
            num_nodes = nodes.data[self.node_key].shape[0]
            if self._use_nodes:
                nodes_to_collect.append(nodes.data[self.node_key])


#             if self._use_sent_edges:
#                 agg_edge_attr = getattr(torch, self._sent_edges_reducer)(nodes.mailbox["m"], dim=1)
#                 nodes_to_collect.append(agg_edge_attr.expand(num_nodes, agg_edge_attr.shape[1]))
            if self._use_received_edges:
                agg_edge_attr = getattr(torch, self._received_edges_reducer)(
                    nodes.mailbox["m"], dim=1)
                nodes_to_collect.append(
                    agg_edge_attr.expand(num_nodes, agg_edge_attr.shape[1]))
            if self._use_globals and global_attr is not None:
                # self._global_attr = global_attr.unsqueeze(0)    # make global_attr.shape = (1, DIM)
                # expanded_global_attr = self._global_attr.expand(num_nodes, self._global_attr.shape[1])
                expanded_global_attr = nodes.data['expanded_global_attr']
                nodes_to_collect.append(expanded_global_attr)

            collected_nodes = torch.cat(nodes_to_collect, dim=-1)

            if self.recurrent:
                return {
                    out_node_key:
                    self.net(collected_nodes, nodes.data[out_node_key])
                }
            else:
                return {out_node_key: self.net(collected_nodes)}

        graph.ndata['expanded_global_attr'] = dgl.broadcast_nodes(
            graph, global_attr)
        if self._use_received_edges:
            graph.update_all(fn.copy_e(self.edge_key, "m"), recv_func)  # trick
        else:
            graph.apply_nodes(recv_func)

        return graph
Example #17
    def forward(self, graph, feat):
        with graph.local_scope():
            batch_size = graph.batch_size

            h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
                 feat.new_zeros((self.n_layers, batch_size, self.input_dim))
                 )  #(6, 32, 100)

            q_star = feat.new_zeros(batch_size, self.output_dim)  #(32, 200)
            #print(q_star.shape)
            for i in range(self.n_iters):
                q, h = self.lstm(q_star.unsqueeze(0), h)
                q = q.view(batch_size, self.input_dim)
                e = (feat * dgl.broadcast_nodes(graph, q)).sum(dim=-1,
                                                               keepdim=True)

                graph.ndata['e'] = e
                alpha = dgl.softmax_nodes(graph, 'e')
                graph.ndata['r'] = feat * alpha
                readout = dgl.sum_nodes(graph, 'r')
                q_star = torch.cat([q, readout], dim=-1)

            return q_star
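DGL also ships a built-in Set2Set readout that follows the same LSTM/attention iteration; a rough, illustrative equivalent of the module above (dimensions are made up):

import dgl
import torch
from dgl.nn import Set2Set

readout = Set2Set(input_dim=100, n_iters=6, n_layers=3)
g = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
feat = torch.randn(g.num_nodes(), 100)
q_star = readout(g, feat)   # (batch_size, 2 * input_dim)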
Example #18
    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """Perform one-step readout

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        g_feats : float32 tensor of shape (G, graph_feat_size)
            Input graph features. G for the number of graphs.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, graph_feat_size)
            Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout.
        """
        with g.local_scope():
            g.ndata['z'] = self.compute_logits(
                torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)

            if isinstance(g, BatchedDGLGraph):
                g_repr = dgl.sum_nodes(g, 'hv', 'a')
            else:
                g_repr = dgl.sum_nodes(g, 'hv', 'a').unsqueeze(0)
            context = F.elu(g_repr)

            if get_node_weight:
                return self.gru(context, g_feats), g.ndata['a']
            else:
                return self.gru(context, g_feats)
Example #19
    def forward(self, graph: dgl.DGLGraph, feat: torch.Tensor) -> torch.Tensor:
        """
        Compute set2set pooling.

        Args:
            graph: the input graph
            feat: The input feature with shape :math:`(N, D)` where  :math:`N` is the
                number of nodes in the graph, and :math:`D` means the size of features.

        Returns:
            The output feature with shape :math:`(B, D)`, where :math:`B` refers to
            the batch size, and :math:`D` means the size of features.
        """
        with graph.local_scope():
            batch_size = graph.batch_size

            h = (
                feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
                feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
            )

            q_star = feat.new_zeros(batch_size, self.output_dim)

            for _ in range(self.n_iters):
                q, h = self.lstm(q_star.unsqueeze(0), h)
                q = q.view(batch_size, self.input_dim)
                e = (feat *
                     dgl.broadcast_nodes(graph, q, ntype=self.ntype)).sum(
                         dim=-1, keepdim=True)
                graph.nodes[self.ntype].data["e"] = e
                alpha = dgl.softmax_nodes(graph, "e", ntype=self.ntype)
                graph.nodes[self.ntype].data["r"] = feat * alpha
                readout = dgl.sum_nodes(graph, "r", ntype=self.ntype)
                q_star = torch.cat([q, readout], dim=-1)

            return q_star
Example #20
    def forward(self, g, feature, e):
        h_in = feature  # to be used for residual connection
        lambda_max = [2] * g.batch_size

        def unnLaplacian(feature, D_sqrt, graph):
            """ Operation D^-1/2 A D^-1/2 """
            graph.ndata['h'] = feature * D_sqrt
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            return graph.ndata.pop('h') * D_sqrt

        with g.local_scope():
            D_sqrt = torch.pow(g.in_degrees().float().clamp(
                min=1), -0.5).unsqueeze(-1).to(feature.device)

            lambda_max = [2] * g.batch_size
            
            if lambda_max is None:
                try:
                    lambda_max = dgl.laplacian_lambda_max(g)
                except BaseException:
                    # if the largest eigenvalue is not found
                    lambda_max = [2]

            if isinstance(lambda_max, list):
                lambda_max = torch.Tensor(lambda_max).to(feature.device)
            if lambda_max.dim() == 1:
                lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)

            # broadcast from (B, 1) to (N, 1)
            lambda_max = dgl.broadcast_nodes(g, lambda_max)

            # X_0(f)
            Xt = X_0 = feature

            # X_1(f)
            if self._k > 1:
                re_norm = (2. / lambda_max).to(feature.device)
                h = unnLaplacian(X_0, D_sqrt, g)
                # print('h',h,'norm',re_norm,'X0',X_0)
                X_1 = - re_norm * h + X_0 * (re_norm - 1)

                Xt = torch.cat((Xt, X_1), 1)

            # Xi(x), i = 2...k
            for _ in range(2, self._k):
                h = unnLaplacian(X_1, D_sqrt, g)
                X_i = - 2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0

                Xt = torch.cat((Xt, X_i), 1)
                X_1, X_0 = X_i, X_1

            h = self.linear(Xt)

        if self.batch_norm:
            h = self.batchnorm_h(h)  # batch normalization

        if self.activation:
            h = self.activation(h)

        if self.residual:
            h = h_in + h  # residual connection

        h = self.dropout(h)
        return h, e
Example #21
    def forward(self, g, node_feats, edge_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : LongTensor of shape (N, 1)
            Input categorical node features. N for the number of nodes.
        edge_feats : FloatTensor of shape (E, in_edge_feats)
            Input edge features. E for the number of edges.

        Returns
        -------
        FloatTensor of shape (N, hidden_feats)
            Output node representations
        """
        if self.gnn_type == 'gcn':
            degs = (g.in_degrees().float() + 1).to(node_feats.device)
            norm = torch.pow(degs, -0.5).unsqueeze(-1)  # (N, 1)
            g.ndata['norm'] = norm
            g.apply_edges(fn.u_mul_v('norm', 'norm', 'norm'))
            norm = g.edata.pop('norm')

        if self.virtual_node:
            virtual_node_feats = self.virtual_node_emb(
                torch.zeros(g.batch_size).to(node_feats.dtype).to(
                    node_feats.device))
        h_list = [self.node_encoder(node_feats)]

        for l in range(len(self.layers)):
            if self.virtual_node:
                virtual_feats_broadcast = dgl.broadcast_nodes(
                    g, virtual_node_feats)
                h_list[l] = h_list[l] + virtual_feats_broadcast

            if self.gnn_type == 'gcn':
                h = self.layers[l](g, h_list[l], edge_feats, degs, norm)
            else:
                h = self.layers[l](g, h_list[l], edge_feats)

            if self.batchnorms is not None:
                h = self.batchnorms[l](h)

            if self.activation is not None and l != self.n_layers - 1:
                h = self.activation(h)
            h = self.dropout(h)
            h_list.append(h)

            if l < self.n_layers - 1 and self.virtual_node:
                ### Update virtual node representation from real node representations
                virtual_node_feats_tmp = self.virtual_readout(
                    g, h_list[l]) + virtual_node_feats
                if self.residual:
                    virtual_node_feats = virtual_node_feats + self.dropout(
                        self.mlp_virtual_project[l](virtual_node_feats_tmp))
                else:
                    virtual_node_feats = self.dropout(
                        self.mlp_virtual_project[l](virtual_node_feats_tmp))

        if self.jk:
            return torch.stack(h_list, dim=0).sum(0)
        else:
            return h_list[-1]
Example #22
    def forward(self, graph, feat, lambda_max=None):
        r"""

        Description
        -----------
        Compute ChebNet layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor
            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
            is size of input feature, :math:`N` is the number of nodes.
        lambda_max : list or tensor or None, optional.
            A list(tensor) with length :math:`B`, stores the largest eigenvalue
            of the normalized laplacian of each individual graph in ``graph``,
            where :math:`B` is the batch size of the input graph. Default: None.
            If None, this method would compute the list by calling
            ``dgl.laplacian_lambda_max``.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """
        def unnLaplacian(feat, D_invsqrt, graph):
            """ Operation Feat * D^-1/2 A D^-1/2 但是如果写成矩阵乘法:D^-1/2 A D^-1/2 Feat"""
            graph.ndata['h'] = feat * D_invsqrt
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            return graph.ndata.pop('h') * D_invsqrt

        with graph.local_scope():
            # a small modification; this is the original code
            if self.is_mnist:
                graph.update_all(fn.copy_edge('v', 'm'),
                                 fn.sum('m', 'h'))  # 'v' relates to coordinate.py
                D_invsqrt = th.pow(
                    graph.ndata.pop('h').float().clamp(min=1),
                    -0.5).unsqueeze(-1).to(feat.device)

            #D_invsqrt = th.pow(graph.in_degrees().float().clamp(
            #   min=1), -0.5).unsqueeze(-1).to(feat.device)
            #print("in_degree : ",graph.in_degrees().shape)
            else:
                D_invsqrt = th.pow(graph.in_degrees().float().clamp(min=1),
                                   -0.5).unsqueeze(-1).to(feat.device)
            #print("D_invsqrt : ",D_invsqrt.shape)
            #print("ndata : ",graph.ndata['h'].shape)
            if lambda_max is None:
                try:
                    lambda_max = laplacian_lambda_max(graph)
                except BaseException:
                    # if the largest eigenvalue is not found
                    dgl_warning(
                        "Largest eigonvalue not found, using default value 2 for lambda_max",
                        RuntimeWarning)
                    lambda_max = th.Tensor(2).to(feat.device)

            if isinstance(lambda_max, list):
                lambda_max = th.Tensor(lambda_max).to(feat.device)
            if lambda_max.dim() == 1:
                lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)

            # broadcast from (B, 1) to (N, 1)
            lambda_max = broadcast_nodes(graph, lambda_max)
            re_norm = 2. / lambda_max

            # X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
            Xt = X_0 = feat

            # X_1(f)
            if self._k > 1:
                h = unnLaplacian(X_0, D_invsqrt, graph)
                X_1 = -re_norm * h + X_0 * (re_norm - 1)
                # Concatenate Xt and X_1
                Xt = th.cat((Xt, X_1), 1)

            # Xi(x), i = 2...k
            for _ in range(2, self._k):
                h = unnLaplacian(X_1, D_invsqrt, graph)
                X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
                # Concatenate Xt and X_i
                Xt = th.cat((Xt, X_i), 1)
                X_1, X_0 = X_i, X_1

            # linear projection
            h = self.linear(Xt)

            # activation
            if self.activation:
                h = self.activation(h)
        #print('ChebConv.py Line163 h : ',h.shape)
        return h
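For comparison, a minimal sketch using DGL's built-in ChebConv, which implements the same Chebyshev recurrence as the layer above (sizes are illustrative; passing lambda_max skips the eigensolver):

import dgl
import torch
from dgl.nn import ChebConv

conv = ChebConv(in_feats=16, out_feats=32, k=3)
g = dgl.graph(([0, 1, 2], [1, 2, 0]))          # a 3-node cycle
feat = torch.randn(g.num_nodes(), 16)
out = conv(g, feat, lambda_max=[2.0])          # (3, 32)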