Example #1
    def forward(self, g, node_feats, edge_feats):
        with g.local_scope():
            # Node and edge feature dimensions need to match.
            g.ndata['h'] = node_feats
            g.edata['h'] = self.edge_encoder(edge_feats)
            g.apply_edges(fn.u_add_e('h', 'h', 'm'))

            if self.aggr == 'softmax':
                g.edata['m'] = F.relu(g.edata['m']) + self.eps
                g.edata['a'] = edge_softmax(g, g.edata['m'] * self.beta)
                g.update_all(
                    lambda edge: {'x': edge.data['m'] * edge.data['a']},
                    fn.sum('x', 'm'))

            elif self.aggr == 'power':
                minv, maxv = 1e-7, 1e1
                torch.clamp_(g.edata['m'], minv, maxv)
                g.update_all(
                    lambda edge: {'x': torch.pow(edge.data['m'], self.p)},
                    fn.mean('x', 'm'))
                torch.clamp_(g.ndata['m'], minv, maxv)
                g.ndata['m'] = torch.pow(g.ndata['m'], self.p)

            else:
                raise NotImplementedError(
                    f'Aggregator {self.aggr} is not supported.')

            if self.msg_norm is not None:
                g.ndata['m'] = self.msg_norm(node_feats, g.ndata['m'])

            feats = node_feats + g.ndata['m']

            return self.mlp(feats)
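For context, a minimal standalone sketch of the softmax aggregation used above on a toy graph (beta is treated as a plain float here; eps and the layer's learnable parameters are omitted):

import dgl
import dgl.function as fn
import torch
import torch.nn.functional as F
from dgl.nn.functional import edge_softmax

# Toy graph: edges 0->2 and 1->2
g = dgl.graph(([0, 1], [2, 2]))
g.ndata['h'] = torch.randn(3, 4)
g.edata['h'] = torch.randn(2, 4)

# Message m_uv = h_u + e_uv, exactly as fn.u_add_e produces it
g.apply_edges(fn.u_add_e('h', 'h', 'm'))

# Attention-like weights via softmax over each node's incoming edges
beta = 1.0  # placeholder for the learnable temperature in the layer above
g.edata['a'] = edge_softmax(g, F.relu(g.edata['m']) * beta)
g.update_all(lambda edges: {'x': edges.data['m'] * edges.data['a']},
             fn.sum('x', 'm'))
print(g.ndata['m'].shape)  # torch.Size([3, 4])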
Example #2
    def forward(self, g, split_list, node_feat, edge_feat):
        graph = g.local_var()
        graph.ndata['h_n'] = node_feat
        graph.edata['h_e'] = edge_feat
        graph.update_all(fn.u_add_e('h_n', 'h_e', 'm'),
                         self._reducer('m', 'neigh'))
        rst = (1 + self.eps) * node_feat + graph.ndata['neigh']

        if self.apply_func is not None:
            rst = self.apply_func(g, rst)
        return rst
Example #3
    def forward(self, graph, node_feat, edge_feat):
        graph = graph.local_var()
        graph.ndata['h_n'] = node_feat
        graph.edata['h_e'] = edge_feat

        ### u, v, e represent source nodes, destination nodes, and the edges between them
        graph.update_all(fn.u_add_e('h_n', 'h_e', 'm'), fn.sum('m', 'neigh'))
        rst = (1 + self.eps) * node_feat + graph.ndata['neigh']
        rst = self.mlp(rst)

        return rst
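Examples #2 and #3 both implement the GIN-style update rst = (1 + eps) * h_v + sum over incoming edges of (h_u + e_uv). A minimal self-contained check of that aggregation, with eps as a plain float and the MLP left out:

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1], [2, 2]))  # two edges into node 2
h = torch.ones(3, 2)
e = torch.full((2, 2), 0.5)
eps = 0.1

g.ndata['h_n'] = h
g.edata['h_e'] = e
g.update_all(fn.u_add_e('h_n', 'h_e', 'm'), fn.sum('m', 'neigh'))
rst = (1 + eps) * h + g.ndata['neigh']
print(rst[2])  # tensor([4.1000, 4.1000]): 1.1 self term + 2 * (1.0 + 0.5)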
Example #4
    def forward(self, g, node_feats, categorical_edge_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : FloatTensor of shape (N, emb_dim)
            * Input node features
            * N is the total number of nodes in the batch of graphs
            * emb_dim is the input node feature size, which must match emb_dim in initialization
        categorical_edge_feats : list of LongTensor of shape (E)
            * Input categorical edge features
            * len(categorical_edge_feats) should be the same as len(self.edge_embeddings)
            * E is the total number of edges in the batch of graphs

        Returns
        -------
        node_feats : float32 tensor of shape (N, emb_dim)
            Output node representations
        """
        edge_embeds = []
        for i, feats in enumerate(categorical_edge_feats):
            edge_embeds.append(self.edge_embeddings[i](feats))
        edge_embeds = torch.stack(edge_embeds, dim=0).sum(0)
        g = g.local_var()
        g.ndata['feat'] = node_feats
        g.edata['feat'] = edge_embeds
        g.update_all(fn.u_add_e('feat', 'feat', 'm'), fn.sum('m', 'feat'))

        node_feats = self.mlp(g.ndata.pop('feat'))
        if self.bn is not None:
            node_feats = self.bn(node_feats)
        if self.activation is not None:
            node_feats = self.activation(node_feats)

        return node_feats
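For reference, edge_embeddings in this pattern is typically an nn.ModuleList of nn.Embedding modules, one table per categorical edge-feature field. A minimal sketch with placeholder vocabulary sizes:

import torch
import torch.nn as nn

# One embedding table per categorical field (vocab sizes 5 and 3 are placeholders)
edge_embeddings = nn.ModuleList([nn.Embedding(5, 16), nn.Embedding(3, 16)])

# Two fields, four edges each
categorical_edge_feats = [torch.tensor([0, 1, 2, 0]), torch.tensor([1, 0, 2, 2])]

embeds = [emb(feats) for emb, feats in zip(edge_embeddings, categorical_edge_feats)]
edge_embeds = torch.stack(embeds, dim=0).sum(0)  # (4, 16), summed over fields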
Example #5
    def forward(self, agg_graph: dgl.DGLGraph, prop_graph: dgl.DGLGraph,
                traversal_order, new_node_ids) -> torch.Tensor:
        tg = agg_graph.local_var()
        pg = prop_graph.local_var()

        nfeat = tg.ndata["nfeat"]
        # h_self = nfeat
        h_self = self.encode_time(nfeat, tg.ndata["timestamp"])
        tg.ndata["nfeat"] = h_self
        tg.edata["efeat"] = self.fc_edge(tg.edata["efeat"])
        # efeat = tg.edata["efeat"]
        # tg.apply_edges(lambda edges: {
        #     "efeat":
        #     torch.cat((edges.src["nfeat"], edges.data["efeat"]), dim=1)
        # })
        # tg.edata["efeat"] = self.encode_time(tg.edata["efeat"], tg.edata["timestamp"])
        degs = tg.ndata["degree"]

        # agg_graph aggregation
        if self._agg_type == "pool":
            tg.edata["efeat"] = F.relu(self.fc_pool(tg.edata["efeat"]))
            tg.update_all(fn.u_add_e("nfeat", "efeat", "m"),
                          fn.max("m", "neigh"))
            h_neigh = tg.ndata["neigh"]
        elif self._agg_type in ["mean", "gcn", "lstm"]:
            tg.update_all(fn.u_add_e("nfeat", "efeat", "m"),
                          fn.sum("m", "neigh"))
            h_neigh = tg.ndata["neigh"]
        else:
            raise KeyError("Aggregator type {} not recognized.".format(
                self._agg_type))

        pg.ndata["neigh"] = h_neigh
        # prop_graph propagation
        # NOTE: this first branch (prop_nodes-based propagation) is disabled
        # via `if False` and kept for reference; the cumulative-sum/max path
        # in the else branch is the one that actually runs.
        if False:
            if self._agg_type == "mean":
                pg.prop_nodes(traversal_order,
                              message_func=fn.copy_src("neigh", "tmp"),
                              reduce_func=fn.sum("tmp", "acc"))
                h_neigh = h_neigh + pg.ndata["acc"]
                h_neigh = h_neigh / degs.unsqueeze(-1)
            elif self._agg_type == "gcn":
                pg.prop_nodes(traversal_order,
                              message_func=fn.copy_src("neigh", "tmp"),
                              reduce_func=fn.sum("tmp", "acc"))
                h_neigh = h_neigh + pg.ndata["acc"]
                h_neigh = (h_self + h_neigh) / (degs.unsqueeze(-1) + 1)
            elif self._agg_type == "pool":
                pg.prop_nodes(traversal_order,
                              message_func=fn.copy_src("neigh", "tmp"),
                              reduce_func=fn.max("tmp", "acc"))
                h_neigh = torch.max(h_neigh, pg.ndata["acc"])
            elif self._agg_type == "lstm":
                h_neighs = [
                    self._lstm_reducer(h_neigh[ids]) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
        else:
            if self._agg_type == "mean":
                h_neighs = [
                    torch.cumsum(h_neigh[ids], dim=0) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
                h_neigh = h_neigh / degs.unsqueeze(-1)
            elif self._agg_type == "gcn":
                h_neighs = [
                    torch.cumsum(h_neigh[ids], dim=0) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
                h_neigh = (h_self + h_neigh) / (degs.unsqueeze(-1) + 1)
            elif self._agg_type == "pool":
                # torch.cummax returns a (values, indices) tuple; keep the values
                h_neighs = [
                    torch.cummax(h_neigh[ids], dim=0).values
                    for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
            elif self._agg_type == "lstm":
                h_neighs = [
                    self._lstm_reducer(h_neigh[ids]) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]

        if self._agg_type == "gcn":
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        return rst
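The ridx construction repeated in each branch above builds the inverse of the permutation given by new_node_ids, so rows computed in traversal order can be scattered back into original node order. A small illustration with made-up ids:

import numpy as np
import torch

new_node_ids = [np.array([2, 0]), np.array([1, 3])]   # hypothetical traversal groups
vals = torch.tensor([10., 20., 30., 40.])             # rows for nodes 2, 0, 1, 3

ridx = torch.arange(4)
ridx[np.concatenate(new_node_ids)] = torch.arange(4)
print(vals[ridx])  # tensor([20., 30., 10., 40.]) -- back in node-id order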
Example #6
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
import Processtest
from dgl.nn import GATConv, SAGEConv, GINConv

# Message function and reduce (aggregation) function
gcn_msg = fn.u_add_e('h', 'w', 'm')
#gcn_msg = fn.copy_src(src='h', out='m')

gcn_reduce = fn.sum(msg='m', out='h')


class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        with g.local_scope():
            g.ndata['h'] = feature
            g.update_all(gcn_msg, gcn_reduce)
            h = g.ndata['h']
            return self.linear(h)
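
# A toy forward pass through GCNLayer (a hypothetical usage sketch, reusing the
# imports above): fn.u_add_e('h', 'w', 'm') requires edge data under 'w' whose
# shape broadcasts against the node features.
_g = dgl.graph(([0, 1], [2, 2]))
_g.ndata['h'] = th.randn(3, 8)
_g.edata['w'] = th.randn(2, 8)
_layer = GCNLayer(8, 4)
print(_layer(_g, _g.ndata['h']).shape)  # torch.Size([3, 4])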


class Net(nn.Module):
    def __init__(self):