Exemplo n.º 1
    def forward(self, input_hidden, graphs: dgl.DGLGraph, batch_num_nodes=None):
        """Refine node states against a query and return padded per-graph features.

        Parameters
        ----------
        input_hidden : tensor
            Query hidden state. It is projected by ``self.input_proj`` and then
            broadcast to every node and edge of ``graphs``.
        graphs : dgl.DGLGraph
            Batched graph; ``ndata['F_n']`` and ``edata['F_e']`` are read.
        batch_num_nodes : optional
            Split sizes (nodes per graph) handed to ``torch.split``. When
            ``None`` they are taken from ``graphs.batch_num_nodes``.

        Returns
        -------
        io : tensor
            Node features split per graph and padded with ``pad_sequence``
            (``batch_first=True``).
        io_mask : bool tensor
            ``True`` for slots whose feature vector does not sum to zero.
        """
        if batch_num_nodes is None:
            b_num_nodes = graphs.batch_num_nodes
            # Newer DGL releases expose batch_num_nodes as a method returning a
            # tensor, older ones as a list attribute; torch.split needs ints.
            if callable(b_num_nodes):
                b_num_nodes = b_num_nodes().tolist()
        else:
            b_num_nodes = batch_num_nodes
        h_t = self.input_proj(input_hidden)
        # when there are no edges in the graph, there is nothing to do
        if graphs.number_of_edges() > 0:
            # give all the nodes and edges information about the current query hidden state
            graphs.ndata['h_t'] = dgl.broadcast_nodes(graphs, h_t)
            graphs.edata['h_t'] = dgl.broadcast_edges(graphs, h_t)
            # create a copy of the node and edge states which will be updated for K iterations
            graphs.ndata['F_n_t'] = graphs.ndata['F_n']
            graphs.edata['F_e_t'] = graphs.edata['F_e']

            for _ in range(self.k_update_steps):
                graphs.ndata['s_n'] = self.object_score(
                    torch.cat([graphs.ndata['h_t'], graphs.ndata['F_n_t']], dim=-1))
                # NOTE(review): the step that writes ndata['F_i_tplus1'] (and
                # edata['F_e_tplus1']) is not visible in this excerpt; it must
                # run before the two assignments below.
                graphs.ndata['F_n_t'] = graphs.ndata['F_i_tplus1']
                if self.update_relations:
                    graphs.edata['F_e_t'] = graphs.edata['F_e_tplus1']

            io = torch.split(graphs.ndata['F_n_t'], split_size_or_sections=b_num_nodes)
        else:
            # No edges: fall back to the untouched input node features.
            io = torch.split(graphs.ndata['F_n'], split_size_or_sections=b_num_nodes)
        io = pad_sequence(io, batch_first=True)
        io_mask = io.sum(dim=-1) != 0

        return io, io_mask
Exemplo n.º 2
    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """Perform one attention-based readout step.

        Parameters
        ----------
        g : DGLGraph or BatchedDGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        g_feats : float32 tensor of shape (G, N2)
            Input graph features. G for the number of graphs and N2 for the feature size.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, N2)
            Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout. Only returned when
            ``get_node_weight`` is true.
        """
        with g.local_scope():
            # Score every node against the (ReLU-activated) feature of its graph.
            g.ndata['z'] = self.compute_logits(
                torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)
            context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))

            if get_node_weight:
                return self.gru(context, g_feats), g.ndata['a']
            else:
                return self.gru(context, g_feats)
Exemplo n.º 3
    def forward(self, graph, node_feat, edge_feat):
        """Run the stacked conv layers and predict from the pooled node states.

        When ``self.virtual_node`` is set, a per-graph virtual-node state is
        added to the node states before each layer and refreshed from the
        pooled node states after every layer except the last one.
        """
        if self.virtual_node:
            vn_state = self.virtual_emb.weight.expand(graph.batch_size, -1)

        node_h = self.node_encoder(node_feat)
        final_idx = self.num_layers - 1

        for idx in range(self.num_layers):
            is_final = idx == final_idx

            if self.virtual_node:
                # virtual node -> graph nodes
                node_h = node_h + dgl.broadcast_nodes(graph, vn_state)

            edge_h = self.edge_encoders[idx](edge_feat)
            node_h = self.conv_layers[idx](graph, node_h, edge_h)
            # The last layer skips the non-linearity but still applies dropout.
            node_h = self.dropout(node_h if is_final else F.relu(node_h))

            if self.virtual_node and not is_final:
                # graph nodes -> virtual node
                pooled = self.virtual_pool(graph, node_h) + vn_state
                vn_state = self.dropout(F.relu(self.mlp_virtual[idx](pooled)))

        return self.pred(self.pool(graph, node_h))
Exemplo n.º 4
    def forward(self, inputs, extra_inputs=None):
        """Score all candidate items for a batch of (user id, session graph) inputs.

        ``inputs`` unpacks to ``(uid, g)``; ``extra_inputs`` is forwarded to the
        parent ``forward``, whose result is indexed by ``'item'`` and ``'user'``.
        """
        kg_emb = super().forward(extra_inputs)

        uid, g = inputs
        item_ids = g.ndata['iid']  # (num_nodes,)
        item_feat = kg_emb['item'][item_ids]
        feat_u = kg_emb['user'][uid]

        # Node input = projected item feature + projected user feature of its graph.
        item_part = self.fc_i(item_feat)
        user_part = dgl.broadcast_nodes(g, self.fc_u(feat_u))
        session_feat = self.PSE_layer(g, item_part + user_part)

        session_repr = self.fc_sr(th.cat([session_feat, feat_u], dim=1))
        candidates = self.item_embedding(self.item_indices)
        return session_repr @ candidates.t()
Exemplo n.º 5
Arquivo: lessr.py Projeto: lessr/lessr
 def forward(self, g, feat, last_nodes):
     """Attention readout: weight the node features and sum them per graph.

     ``last_nodes`` indexes the rows of ``feat`` whose projection is broadcast
     to the nodes and added to each node's own projection before scoring.
     """
     with g.local_scope():
         if self.batch_norm is not None:
             feat = self.batch_norm(feat)
         node_term = self.fc_u(feat)
         query_term = dgl.broadcast_nodes(g, self.fc_v(feat[last_nodes]))
         # Scores are stored on the graph so softmax_nodes can read them by key.
         g.ndata['e'] = self.attn_e(th.sigmoid(node_term + query_term))
         g.ndata['w'] = feat * dgl.softmax_nodes(g, 'e')
         return self.fc_out(dgl.sum_nodes(g, 'w'))
Exemplo n.º 6
    def forward(self, g, x, edge_attr):
        """Compute node representations with a virtual node attached to each graph.

        Parameters
        ----------
        g : DGLGraph
            Batched graph; ``g.batch_size`` gives the number of graphs.
        x : tensor
            Node inputs passed to ``self.atom_encoder``.
        edge_attr : tensor
            Edge inputs passed to every layer of ``self.convs``.

        Returns
        -------
        tensor
            ``h_list[-1]`` when ``self.JK == "last"``, or the sum of the first
            ``self.num_layers`` entries of ``h_list`` when ``self.JK == "sum"``.

        Raises
        ------
        ValueError
            If ``self.JK`` is neither ``"last"`` nor ``"sum"``.
        """
        ### virtual node embeddings for graphs
        # NOTE(review): the argument of this call was cut off in the original;
        # one zero index per graph is assumed (single-entry embedding) — confirm.
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(g.batch_size, dtype=torch.long, device=x.device))

        h_list = [self.atom_encoder(x)]
        # graph index of every node, used to look up its virtual node
        batch_id = dgl.broadcast_nodes(g, torch.arange(g.batch_size).to(x.device))
        for layer in range(self.num_layers):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id]

            ### Message passing among graph nodes
            h = self.convs[layer](g, h_list[layer], edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h = h + h_list[layer]

            # Without this append, h_list[layer] fails on the next iteration.
            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layers - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = self.pool(g, h_list[layer]) + virtualnode_embedding
                ### transform virtual nodes using MLP
                virtualnode_embedding_temp = self.mlp_virtualnode_list[layer](
                    virtualnode_embedding_temp)

                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(
                        virtualnode_embedding_temp, self.drop_ratio, training=self.training)
                else:
                    virtualnode_embedding = F.dropout(
                        virtualnode_embedding_temp, self.drop_ratio, training=self.training)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers):
                node_representation += h_list[layer]
        else:
            raise ValueError(f"unsupported JK mode: {self.JK!r}")

        return node_representation
Exemplo n.º 7
def collate(samples):
    ''' Merge (graph, target) samples into one batched graph plus a target tensor.

    Every node of the batched graph is tagged in ``ndata['graph_id']`` with
    the position of its source graph inside the batch.
    '''
    graphs, targets = (list(column) for column in zip(*samples))

    batch = dgl.batch(graphs)
    target_tensor = th.Tensor(targets)

    # one id per graph, repeated for each of that graph's nodes
    per_graph_ids = th.arange(len(graphs))
    batch.ndata['graph_id'] = dgl.broadcast_nodes(batch, per_graph_ids)
    return batch, target_tensor
Exemplo n.º 8
def test_broadcast_nodes():
    """Check dgl.broadcast_nodes on a single graph and on a batch of graphs."""
    # test#1: basic
    g0 = dgl.DGLGraph(nx.path_graph(10))
    feat0 = F.randn((40, ))
    ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
    assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)

    # test#2: batched graph (g2 is empty and so contributes no rows)
    g1 = dgl.DGLGraph(nx.path_graph(3))
    g2 = dgl.DGLGraph()
    g3 = dgl.DGLGraph(nx.path_graph(12))
    bg = dgl.batch([g0, g1, g2, g3])
    feat1 = F.randn((40, ))
    feat2 = F.randn((40, ))
    feat3 = F.randn((40, ))
    ground_truth = F.stack(
        [feat0] * g0.number_of_nodes() +
        [feat1] * g1.number_of_nodes() +
        [feat2] * g2.number_of_nodes() +
        [feat3] * g3.number_of_nodes(), 0
    )
    assert F.allclose(
        dgl.broadcast_nodes(bg, F.stack([feat0, feat1, feat2, feat3], 0)),
        ground_truth
    )
Exemplo n.º 9
 def forward(self, g, feat, last_nodes):
     """Attention readout built on segment softmax / segment sum.

     Node features are optionally batch-normalised, passed through
     ``self.feat_drop``, scored against the broadcast projection of the rows
     selected by ``last_nodes``, and summed per graph with softmax weights.
     ``self.fc_out`` and ``self.activation`` are applied when not ``None``.
     """
     if self.batch_norm is not None:
         feat = self.batch_norm(feat)
     feat = self.feat_drop(feat)

     keys = self.fc_u(feat)
     query = dgl.broadcast_nodes(g, self.fc_v(feat[last_nodes]))
     scores = self.fc_e(th.sigmoid(keys + query))

     weights = F.segment.segment_softmax(g.batch_num_nodes(), scores)
     pooled = F.segment.segment_reduce(g.batch_num_nodes(), feat * weights, 'sum')

     # optional post-processing, applied in this fixed order
     for post in (self.fc_out, self.activation):
         if post is not None:
             pooled = post(pooled)
     return pooled
Exemplo n.º 10
def collate(samples):
    ''' Batch (graph, diffusion graph, label) samples for the dataloader.

    Returns the batched graph, the batched diffusion graph and a label tensor.
    Nodes of the batched graph get ``ndata['graph_id']`` set to the position
    of their source graph in the batch.
    '''
    graphs, diff_graphs, labels = (list(column) for column in zip(*samples))

    batch = dgl.batch(graphs)
    label_tensor = th.tensor(labels)
    diff_batch = dgl.batch(diff_graphs)

    # one id per graph, repeated for each of that graph's nodes
    per_graph_ids = th.arange(len(graphs))
    batch.ndata['graph_id'] = dgl.broadcast_nodes(batch, per_graph_ids)

    return batch, diff_batch, label_tensor
Exemplo n.º 11
 def forward(self, g, feat, last_nodes):
     """Count-weighted attention readout (no softmax normalisation).

     Each node score from ``self.fc_e`` is multiplied by ``g.ndata['cnt']``
     and the weighted node features are summed per graph. Batch norm,
     feature dropout, ``fc_out`` and ``activation`` are skipped when ``None``.
     """
     for pre in (self.batch_norm, self.feat_drop):
         if pre is not None:
             feat = pre(feat)

     keys = self.fc_u(feat)
     query = dgl.broadcast_nodes(g, self.fc_v(feat[last_nodes]))
     scores = self.fc_e(th.sigmoid(keys + query))  # (num_nodes, 1)

     weights = scores * g.ndata['cnt'].view_as(scores)
     pooled = F.segment.segment_reduce(g.batch_num_nodes(), feat * weights, 'sum')

     for post in (self.fc_out, self.activation):
         if post is not None:
             pooled = post(pooled)
     return pooled
Exemplo n.º 12
 def forward(self, graph: dgl.DGLGraph, feat, lambda_max=None):
     """Apply the Chebyshev convolution to ``feat``.

     Parameters
     ----------
     graph : dgl.DGLGraph
         Input (possibly batched) graph.
     feat : tensor
         Node features; the first dimension indexes nodes.
     lambda_max : list or tensor or None
         Largest Laplacian eigenvalue per graph. When ``None`` it is computed
         with ``laplacian_lambda_max``; if ARPACK does not converge, 2 is
         used for every graph.

     Returns
     -------
     tensor
         Sum of ``self.fc[i]`` applied to the i-th Chebyshev term, plus
         ``self.bias`` when it is not ``None``.
     """
     # (N, 1, ..., 1): lets per-node scalars broadcast over all trailing dims of feat
     shp = (graph.number_of_nodes(), ) + tuple(1 for _ in range(feat.dim() - 1))
     with graph.local_scope():
         # NOTE(review): the tails of this call and of the torch.reshape call
         # below were cut off in the original; the exponent -0.5 and the use
         # of `shp` are reconstructed — confirm.
         norm = torch.pow(graph.in_degrees().float().clamp(min=1),
                          -0.5).reshape(shp).to(feat.device)
         if lambda_max is None:
             try:
                 lambda_max = laplacian_lambda_max(graph)
             except ArpackNoConvergence:
                 lambda_max = [2.] * graph.batch_size
         if isinstance(lambda_max, list):
             lambda_max = torch.tensor(lambda_max).to(feat.device)
         if lambda_max.dim() < 1:
             lambda_max = lambda_max.unsqueeze(-1)  # 0-d scalar to (1,)
         # broadcast one value per graph to one value per node
         lambda_max = torch.reshape(broadcast_nodes(graph, lambda_max), shp)
         # T0(X)
         Tx_0 = feat
         rst = self.fc[0](Tx_0)
         # T1(X)
         if self._k > 1:
             graph.ndata['h'] = Tx_0 * norm
             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
             h = graph.ndata.pop('h') * norm
             # Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I
             #   = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I
             Tx_1 = -2. * h / lambda_max + Tx_0 * (2. / lambda_max - 1)
             rst = rst + self.fc[1](Tx_1)
         # Ti(x), i = 2...k
         for i in range(2, self._k):
             graph.ndata['h'] = Tx_1 * norm
             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
             h = graph.ndata.pop('h') * norm
             # Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2)
             #      = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) +
             #        (4 / lambda_max - 2) Tx_(k-1) -
             #        Tx_(k-2)
             Tx_2 = -4. * h / lambda_max + Tx_1 * (4. / lambda_max -
                                                   2) - Tx_0
             rst = rst + self.fc[i](Tx_2)
             Tx_1, Tx_0 = Tx_2, Tx_1
         # add bias
         if self.bias is not None:
             rst = rst + self.bias
         return rst
Exemplo n.º 13
def test_broadcast(idtype, g):
    """Check that broadcast_nodes / broadcast_edges repeat each graph's row."""
    g = g.astype(idtype).to(F.ctx())
    gfeat = F.randn((g.batch_size, 3))

    # Test.0: broadcast_nodes
    g.ndata['h'] = dgl.broadcast_nodes(g, gfeat)
    subg = dgl.unbatch(g)
    for i, sg in enumerate(subg):
        # every node of graph i must carry gfeat[i]
        assert F.allclose(
            sg.ndata['h'],
            F.repeat(F.reshape(gfeat[i], (1, 3)), sg.number_of_nodes(), dim=0))

    # Test.1: broadcast_edges
    g.edata['h'] = dgl.broadcast_edges(g, gfeat)
    subg = dgl.unbatch(g)
    for i, sg in enumerate(subg):
        # every edge of graph i must carry gfeat[i]
        assert F.allclose(
            sg.edata['h'],
            F.repeat(F.reshape(gfeat[i], (1, 3)), sg.number_of_edges(), dim=0))
Exemplo n.º 14
 def forward(self, g, feat_i, feat_u, last_nodes):
     """Attention readout over item nodes, conditioned on the user feature.

     Parameters
     ----------
     g : DGLGraph
         Batched graph; ``ndata['cnt']`` is read and its log is added to the
         attention scores.
     feat_i : tensor
         Item (node) features.
     feat_u : tensor
         User features; after ``self.fc_user`` they are added to the
         projection of ``feat_i[last_nodes]`` and broadcast to the nodes.
     last_nodes : tensor
         Indices into ``feat_i`` selecting the rows used for the query.

     Returns
     -------
     tensor
         Softmax-weighted per-graph sum of ``feat_i``, passed through
         ``self.activation`` when it is not ``None``.
     """
     if self.batch_norm is not None:
         feat_i = self.batch_norm['item'](feat_i)
         feat_u = self.batch_norm['user'](feat_u)
     if self.feat_drop is not None:
         feat_i = self.feat_drop(feat_i)
         feat_u = self.feat_drop(feat_u)
     feat_val = feat_i
     feat_key = self.fc_key(feat_i)
     feat_u = self.fc_user(feat_u)
     feat_last = self.fc_last(feat_i[last_nodes])
     feat_qry = dgl.broadcast_nodes(g, feat_u + feat_last)
     e = self.fc_e(th.sigmoid(feat_qry + feat_key))  # (num_nodes, 1)
     e = e + g.ndata['cnt'].log().view_as(e)
     alpha = F.segment.segment_softmax(g.batch_num_nodes(), e)
     # NOTE(review): the reducer argument was cut off in the original; 'sum'
     # matches the sibling readout layers in this file — confirm.
     rst = F.segment.segment_reduce(g.batch_num_nodes(), alpha * feat_val,
                                    'sum')
     if self.activation is not None:
         rst = self.activation(rst)
     return rst
    # NOTE(review): the end of this signature was cut off in the original; the
    # trailing parameter name below is an assumption — confirm against callers.
    def forward(self, graphs, nodes_feat, edges_feat, nodes_num_norm_sqrt,
                edges_num_norm_sqrt):
        """Classify graphs from attention-rescaled mean node features.

        Parameters
        ----------
        graphs : DGLGraph
            Batched graph; ``ndata['h']`` is overwritten by this call.
        nodes_feat : tensor
            Node inputs passed to ``self.embedding_h``.
        edges_feat : tensor
            Not used by the visible body.
        nodes_num_norm_sqrt : tensor
            Passed to every layer in ``self.layers``.
        edges_num_norm_sqrt : tensor
            Not used by the visible body.

        Returns
        -------
        tensor
            Output of ``self.readout_mlp`` on the per-graph mean node features.
        """
        h = self.embedding_h(nodes_feat)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            h = conv(graphs, h, nodes_num_norm_sqrt)

        # Concatenate each node with the mean feature of its own graph.
        graphs.ndata['h'] = h
        h_mean = dgl.mean_nodes(graphs, 'h')
        h_mean = dgl.broadcast_nodes(graphs, h_mean)
        h_mean_and_h = torch.cat([h, h_mean], dim=-1)

        # 2 * sigmoid(.) - 1 lies in (-1, 1), so each node is scaled by 1 + that value.
        nodes_attention = 2 * torch.sigmoid(
            self.attention2(torch.relu(self.attention1(h_mean_and_h)))) - 1
        h = h + h * nodes_attention
        graphs.ndata['h'] = h
        hg = dgl.mean_nodes(graphs, 'h')

        logits = self.readout_mlp(hg)
        return logits
    def forward(self, graph, global_attr=None, out_node_key='h_v'):
        """Update node attributes from node, received-edge and global inputs.

        Parameters
        ----------
        graph : DGLGraph
            Graph that is updated in place; it is also the return value.
        global_attr : tensor, optional
            Per-graph attributes; broadcast to the nodes when given.
        out_node_key : str
            Node-data key the new node attributes are written to.

        Returns
        -------
        DGLGraph
            The same ``graph`` object.
        """
        def recv_func(nodes):
            # NOTE(review): the append calls, the returned dict key and the
            # non-recurrent branch were cut off in the original and are
            # reconstructed from the surrounding lines — confirm.
            nodes_to_collect = []
            num_nodes = nodes.data[self.node_key].shape[0]
            if self._use_nodes:
                nodes_to_collect.append(nodes.data[self.node_key])
            # (aggregation over sent edges is disabled in the original)
            if self._use_received_edges:
                agg_edge_attr = getattr(torch, self._received_edges_reducer)(
                    nodes.mailbox["m"], dim=1)
                nodes_to_collect.append(
                    agg_edge_attr.expand(num_nodes, agg_edge_attr.shape[1]))
            if self._use_globals and global_attr is not None:
                # already expanded to one row per node by broadcast_nodes below
                nodes_to_collect.append(nodes.data['expanded_global_attr'])

            collected_nodes = torch.cat(nodes_to_collect, dim=-1)

            if self.recurrent:
                return {
                    out_node_key:
                    self.net(collected_nodes, nodes.data[out_node_key])
                }
            return {out_node_key: self.net(collected_nodes)}

        # Broadcasting None would fail, and recv_func ignores globals in that case.
        if global_attr is not None:
            graph.ndata['expanded_global_attr'] = dgl.broadcast_nodes(
                graph, global_attr)
        if self._use_received_edges:
            graph.update_all(fn.copy_e(self.edge_key, "m"), recv_func)  # trick

        return graph
Exemplo n.º 17
    def forward(self, graph, feat):
        """Compute the set2set readout of ``feat`` over each graph in the batch.

        Parameters
        ----------
        graph : DGLGraph
            Batched graph; ``graph.batch_size`` gives the number of graphs.
        feat : tensor
            Node features, one row per node.

        Returns
        -------
        tensor
            ``q_star`` after ``self.n_iters`` iterations: the LSTM query
            concatenated with the attention readout, one row per graph.
        """
        with graph.local_scope():
            batch_size = graph.batch_size

            # zero-initialised LSTM (hidden, cell) state, e.g. (6, 32, 100)
            h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
                 feat.new_zeros((self.n_layers, batch_size, self.input_dim)))

            q_star = feat.new_zeros(batch_size, self.output_dim)  # e.g. (32, 200)
            for _ in range(self.n_iters):
                q, h = self.lstm(q_star.unsqueeze(0), h)
                q = q.view(batch_size, self.input_dim)
                # dot product of each node with its graph's query; keepdim keeps
                # a trailing axis so `feat * alpha` broadcasts over features
                e = (feat * dgl.broadcast_nodes(graph, q)).sum(dim=-1,
                                                               keepdim=True)

                graph.ndata['e'] = e
                alpha = dgl.softmax_nodes(graph, 'e')
                graph.ndata['r'] = feat * alpha
                readout = dgl.sum_nodes(graph, 'r')
                q_star = torch.cat([q, readout], dim=-1)

            return q_star
Exemplo n.º 18
    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """Perform one-step readout

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        g_feats : float32 tensor of shape (G, graph_feat_size)
            Input graph features. G for the number of graphs.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, graph_feat_size)
            Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout. Only returned when
            ``get_node_weight`` is true.
        """
        with g.local_scope():
            g.ndata['z'] = self.compute_logits(
                torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)

            if isinstance(g, BatchedDGLGraph):
                g_repr = dgl.sum_nodes(g, 'hv', 'a')
            else:
                # a single (unbatched) graph: add the leading graph axis
                g_repr = dgl.sum_nodes(g, 'hv', 'a').unsqueeze(0)
            context = F.elu(g_repr)

            if get_node_weight:
                return self.gru(context, g_feats), g.ndata['a']
            else:
                return self.gru(context, g_feats)
Exemplo n.º 19
    def forward(self, graph: dgl.DGLGraph, feat: torch.Tensor) -> torch.Tensor:
        """
        Compute set2set pooling.

        Args:
            graph: the input graph
            feat: The input feature with shape :math:`(N, D)` where  :math:`N` is the
                number of nodes in the graph, and :math:`D` means the size of features.

        Returns:
            The output feature with shape :math:`(B, D)`, where :math:`B` refers to
            the batch size, and :math:`D` means the size of features.
        """
        with graph.local_scope():
            batch_size = graph.batch_size

            # zero-initialised LSTM (hidden, cell) state
            h = (
                feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
                feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
            )

            q_star = feat.new_zeros(batch_size, self.output_dim)

            for _ in range(self.n_iters):
                q, h = self.lstm(q_star.unsqueeze(0), h)
                q = q.view(batch_size, self.input_dim)
                # score of each node of type self.ntype against its graph's query
                e = (feat *
                     dgl.broadcast_nodes(graph, q, ntype=self.ntype)).sum(
                         dim=-1, keepdim=True)
                graph.nodes[self.ntype].data["e"] = e
                alpha = dgl.softmax_nodes(graph, "e", ntype=self.ntype)
                graph.nodes[self.ntype].data["r"] = feat * alpha
                readout = dgl.sum_nodes(graph, "r", ntype=self.ntype)
                q_star = torch.cat([q, readout], dim=-1)

            return q_star
Exemplo n.º 20
    def forward(self, g, feature, e):
        """Apply the Chebyshev layer to the node features.

        Parameters
        ----------
        g : DGLGraph
            Batched graph; ``g.batch_size`` gives the number of graphs.
        feature : tensor
            Node features, one row per node.
        e : any
            Returned unchanged.

        Returns
        -------
        (h, e)
            ``h`` is the linear projection of the concatenated Chebyshev terms,
            followed by the optional batch norm, activation, residual
            connection and dropout.
        """
        h_in = feature  # to be used for residual connection

        def unnLaplacian(feature, D_sqrt, graph):
            """ Operation D^-1/2 A D^-1/2 """
            graph.ndata['h'] = feature * D_sqrt
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            return graph.ndata.pop('h') * D_sqrt

        with g.local_scope():
            D_sqrt = torch.pow(g.in_degrees().float().clamp(
                min=1), -0.5).unsqueeze(-1).to(feature.device)

            # lambda_max is fixed to 2 for every graph. The original tested
            # `lambda_max is None` right after this assignment, so that
            # eigenvalue-computation branch could never run and is dropped.
            lambda_max = torch.Tensor([2] * g.batch_size).to(feature.device)
            lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)

            # broadcast from (B, 1) to (N, 1)
            lambda_max = dgl.broadcast_nodes(g, lambda_max)

            # X_0(f)
            Xt = X_0 = feature

            # X_1(f)
            if self._k > 1:
                re_norm = (2. / lambda_max).to(feature.device)
                h = unnLaplacian(X_0, D_sqrt, g)
                X_1 = - re_norm * h + X_0 * (re_norm - 1)

                Xt = torch.cat((Xt, X_1), 1)

            # Xi(x), i = 2...k
            for _ in range(2, self._k):
                h = unnLaplacian(X_1, D_sqrt, g)
                X_i = - 2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0

                Xt = torch.cat((Xt, X_i), 1)
                X_1, X_0 = X_i, X_1

            h = self.linear(Xt)

        if self.batch_norm:
            h = self.batchnorm_h(h)  # batch normalization

        if self.activation:
            h = self.activation(h)

        if self.residual:
            h = h_in + h  # residual connection

        h = self.dropout(h)
        return h, e
Exemplo n.º 21
    def forward(self, g, node_feats, edge_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : LongTensor of shape (N, 1)
            Input categorical node features. N for the number of nodes.
        edge_feats : FloatTensor of shape (E, in_edge_feats)
            Input edge features. E for the number of edges.

        Returns
        -------
        FloatTensor of shape (N, hidden_feats)
            Output node representations
        """
        if self.gnn_type == 'gcn':
            degs = (g.in_degrees().float() + 1).to(node_feats.device)
            norm = torch.pow(degs, -0.5).unsqueeze(-1)  # (N, 1)
            g.ndata['norm'] = norm
            # per-edge coefficient: product of the endpoint norms
            g.apply_edges(fn.u_mul_v('norm', 'norm', 'norm'))
            norm = g.edata.pop('norm')

        if self.virtual_node:
            # NOTE(review): the argument of this call was cut off in the
            # original; one zero index per graph is assumed — confirm.
            virtual_node_feats = self.virtual_node_emb(
                torch.zeros(g.batch_size, dtype=torch.long,
                            device=node_feats.device))
        h_list = [self.node_encoder(node_feats)]

        for l in range(len(self.layers)):
            if self.virtual_node:
                virtual_feats_broadcast = dgl.broadcast_nodes(
                    g, virtual_node_feats)
                h_list[l] = h_list[l] + virtual_feats_broadcast

            if self.gnn_type == 'gcn':
                h = self.layers[l](g, h_list[l], edge_feats, degs, norm)
            else:
                h = self.layers[l](g, h_list[l], edge_feats)

            if self.batchnorms is not None:
                h = self.batchnorms[l](h)

            if self.activation is not None and l != self.n_layers - 1:
                h = self.activation(h)
            h = self.dropout(h)
            # Without this append, h_list[l] fails on the next iteration.
            h_list.append(h)

            if l < self.n_layers - 1 and self.virtual_node:
                ### Update virtual node representation from real node representations
                virtual_node_feats_tmp = self.virtual_readout(
                    g, h_list[l]) + virtual_node_feats
                # NOTE(review): the arguments of both dropout calls were cut
                # off in the original; `self.virtual_layers` is an assumed
                # attribute name — confirm against the class definition.
                if self.residual:
                    virtual_node_feats = virtual_node_feats + self.dropout(
                        self.virtual_layers[l](virtual_node_feats_tmp))
                else:
                    virtual_node_feats = self.dropout(
                        self.virtual_layers[l](virtual_node_feats_tmp))

        if self.jk:
            return torch.stack(h_list, dim=0).sum(0)
        else:
            return h_list[-1]
Exemplo n.º 22
    def forward(self, graph, feat, lambda_max=None):
        """
        Compute ChebNet layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor
            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
            is size of input feature, :math:`N` is the number of nodes.
        lambda_max : list or tensor or None, optional.
            A list(tensor) with length :math:`B`, stores the largest eigenvalue
            of the normalized laplacian of each individual graph in ``graph``,
            where :math:`B` is the batch size of the input graph. Default: None.
            If None, this method would compute the list by calling
            ``laplacian_lambda_max``.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """
        import warnings

        def unnLaplacian(feat, D_invsqrt, graph):
            """ Operation Feat * D^-1/2 A D^-1/2; written as a matrix product
            this is D^-1/2 A D^-1/2 Feat"""
            graph.ndata['h'] = feat * D_invsqrt
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            return graph.ndata.pop('h') * D_invsqrt

        with graph.local_scope():
            if self.is_mnist:
                # 'v' is related to coordinate.py
                graph.update_all(fn.copy_edge('v', 'm'),
                                 fn.sum('m', 'h'))
                # NOTE(review): the arguments of th.pow were cut off in the
                # original; using the summed edge value 'h' as the degree is
                # an assumption — confirm.
                weighted_deg = graph.ndata.pop('h').float().clamp(min=1)
                if weighted_deg.dim() == 1:
                    weighted_deg = weighted_deg.unsqueeze(-1)
                D_invsqrt = th.pow(weighted_deg, -0.5).to(feat.device)
            else:
                D_invsqrt = th.pow(graph.in_degrees().float().clamp(min=1),
                                   -0.5).unsqueeze(-1).to(feat.device)

            if lambda_max is None:
                try:
                    lambda_max = laplacian_lambda_max(graph)
                except BaseException:
                    # if the largest eigenvalue is not found
                    warnings.warn(
                        "Largest eigonvalue not found, using default value 2 for lambda_max",
                        RuntimeWarning)
                    # th.Tensor(2) would allocate two uninitialised values;
                    # build an explicit per-graph vector of 2s instead.
                    lambda_max = th.full((graph.batch_size,), 2.0,
                                         device=feat.device)

            if isinstance(lambda_max, list):
                lambda_max = th.Tensor(lambda_max).to(feat.device)
            if lambda_max.dim() == 1:
                lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)

            # broadcast from (B, 1) to (N, 1)
            lambda_max = broadcast_nodes(graph, lambda_max)
            re_norm = 2. / lambda_max

            # X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
            Xt = X_0 = feat

            # X_1(f)
            if self._k > 1:
                h = unnLaplacian(X_0, D_invsqrt, graph)
                X_1 = -re_norm * h + X_0 * (re_norm - 1)
                # Concatenate Xt and X_1
                Xt = th.cat((Xt, X_1), 1)

            # Xi(x), i = 2...k
            for _ in range(2, self._k):
                h = unnLaplacian(X_1, D_invsqrt, graph)
                X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
                # Concatenate Xt and X_i
                Xt = th.cat((Xt, X_i), 1)
                X_1, X_0 = X_i, X_1

            # linear projection
            h = self.linear(Xt)

            # activation
            if self.activation:
                h = self.activation(h)
        return h