Example #1
    def topk_attention(self, graph: DGLGraph):
        graph = graph.local_var()  # the graph is expected to already have self-loop edges added

        def send_edge_message(edges):
            return {'m_e': edges.data['e']}

        def topk_attn_reduce_func(nodes):
            topk = self.top_k
            attentions = nodes.mailbox['m_e']
            neighbor_num = attentions.shape[1]
            if topk > neighbor_num:
                topk = neighbor_num
            topk_atts, _ = torch.topk(attentions, k=topk, dim=1)
            kth_attn_value = topk_atts[:, topk - 1]
            return {'kth_e': kth_attn_value}

        graph.update_all(message_func=send_edge_message,
                         reduce_func=topk_attn_reduce_func)

        def edge_score_update(edges):
            # mask out edges whose score falls below the destination
            # node's k-th largest incoming score
            scores, kth_score = edges.data['e'], edges.dst['kth_e']
            scores[scores < kth_score] = self.attention_mask_value
            return {'e': scores}

        graph.apply_edges(edge_score_update)
        topk_attentions = graph.edata.pop('e')
        return topk_attentions
Example #2
    def forward(self, graph: dgl.DGLGraph):
        graph.apply_nodes(self.input_node_func)
        graph.apply_edges(self.input_edge_func, etype='bond')

        for mp_layer in self.mp_layers:
            mp_layer(graph)

        graph.apply_nodes(self.output_node_func)
Example #3
 def forward(self, graph: dgl.DGLGraph, h):
     # equation (1)
     z = self.fc(h)
     graph.ndata['z'] = z
     # equation (2)
     graph.apply_edges(self.edge_attention)
     # equation (3) & (4)
     graph.update_all(self.message_func, self.reduce_func)
     return graph.ndata.pop('h')
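Example #3 above is the GAT layer from the DGL tutorial; its forward relies on three callbacks that are not shown here. A minimal sketch of what they typically look like (assuming fc and attn_fc are nn.Linear modules created in __init__, and F is torch.nn.functional):

    def edge_attention(self, edges):
        # equation (2): unnormalized attention score per edge
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # equation (3): send the source feature and the raw edge score
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # equation (4): softmax over incoming edges, then weighted sum
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}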
Example #4
    def topk_attention_softmax(self, graph: DGLGraph):
        graph = graph.local_var()

        def send_edge_message(edges):
            return {'m_e': edges.data['e'], 'm_e_id': edges.data['e_id']}

        def topk_attn_reduce_func(nodes):
            topk = self.top_k
            attentions = nodes.mailbox['m_e']
            edge_ids = nodes.mailbox['m_e_id']
            # pad with -1 so nodes with fewer than top_k neighbours can be
            # filtered out later
            topk_edge_ids = torch.full(size=(edge_ids.shape[0], topk),
                                       fill_value=-1,
                                       dtype=torch.long,
                                       device=attentions.device)
            attentions_sum = attentions.sum(dim=2)
            neighbor_num = attentions_sum.shape[1]
            if topk > neighbor_num:
                topk = neighbor_num
            topk_atts, top_k_neighbor_idx = torch.topk(attentions_sum,
                                                       k=topk,
                                                       dim=1)
            top_k_neighbor_idx = top_k_neighbor_idx.squeeze(dim=-1)
            row_idxes = torch.arange(0,
                                     top_k_neighbor_idx.shape[0]).view(-1, 1)
            top_k_attention = attentions[row_idxes, top_k_neighbor_idx]
            top_k_edge_ids = edge_ids[row_idxes, top_k_neighbor_idx]
            top_k_attention_norm = top_k_attention.sum(dim=1)
            topk_edge_ids[:, torch.arange(0, topk)] = top_k_edge_ids
            return {
                'topk_eid': topk_edge_ids,
                'topk_norm': top_k_attention_norm
            }

        graph.update_all(message_func=send_edge_message,
                         reduce_func=topk_attn_reduce_func)
        # collect the ids of the edges that survive the per-node top-k filter
        # and build a 0/1 mask over all edges
        topk_edge_ids = graph.ndata['topk_eid'].flatten()
        topk_edge_ids = topk_edge_ids[topk_edge_ids >= 0]
        mask_edges = torch.zeros((graph.number_of_edges(), 1),
                                 device=topk_edge_ids.device)
        mask_edges[topk_edge_ids] = 1
        attentions = graph.edata['e'].squeeze(dim=-1)
        attentions = attentions * mask_edges
        graph.edata['e'] = attentions.unsqueeze(dim=-1)

        def edge_score_update(edges):
            # normalize the surviving edge scores by the destination
            # node's top-k attention sum
            scores = edges.data['e'] / edges.dst['topk_norm']
            return {'e': scores}

        graph.apply_edges(edge_score_update)
        topk_attentions = graph.edata.pop('e')
        return topk_attentions
Example #5
    def forward(self, graph: dgl.DGLGraph):
        graph.apply_nodes(self.input_node_func)
        if self.fourier_encodings > 0:
            graph.edata['d'] = fourier_encode_dist(
                graph.edata['d'], num_encodings=self.fourier_encodings)
        graph.apply_edges(self.input_edge_func)

        for mp_layer in self.mp_layers:
            mp_layer(graph)

        if self.node_wise_output_layers > 0:
            graph.apply_nodes(self.output_node_func)
        return graph.ndata['feat']
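Examples #5, #8 and #9 call a helper fourier_encode_dist that is defined elsewhere in their project. Its exact form is not shown here; a common multi-scale sin/cos distance encoding (the signature and scale choice below are assumptions, not the original definition) looks roughly like:

    def fourier_encode_dist(x, num_encodings=4, include_self=True):
        # encode a distance tensor with sin/cos features at several scales
        x = x.unsqueeze(-1)
        scales = 2 ** torch.arange(num_encodings, device=x.device, dtype=x.dtype)
        encoded = x / scales
        encoded = torch.cat([encoded.sin(), encoded.cos()], dim=-1)
        if include_self:
            encoded = torch.cat([encoded, x], dim=-1)
        return encoded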
Example #6
    def forward(self, graph: dgl.DGLGraph) -> torch.Tensor:

        with graph.local_scope():

            graph.apply_edges(self._apply_edges, etype="forward")
            h_forward = graph.edges["forward"].data["h"]

        with graph.local_scope():

            graph.apply_edges(self._apply_edges, etype="reverse")
            h_reverse = graph.edges["reverse"].data["h"]

        return self.layers(h_forward) + self.layers(h_reverse)
Example #7
    def forward(self, graph: DGLGraph, features, drop_edge_ids=None):
        # Attention computation: pre-normalization structure
        graph = graph.local_var()
        h = self.graph_norm(features)
        # feat_head = self.fc_head(self.feat_drop(h)).view(-1, self._num_heads, self._att_dim)
        # feat_tail = self.fc_tail(self.feat_drop(h)).view(-1, self._num_heads, self._att_dim)
        feat_head = torch.tanh(self.fc_head(self.feat_drop(h))).view(
            -1, self._num_heads, self._att_dim)
        feat_tail = torch.tanh(self.fc_tail(self.feat_drop(h))).view(
            -1, self._num_heads, self._att_dim)
        # feat_head = F.relu(self.fc_head(self.feat_drop(h))).view(-1, self._num_heads, self._att_dim)
        # feat_tail = F.relu(self.fc_tail(self.feat_drop(h))).view(-1, self._num_heads, self._att_dim)
        feat = self.fc(self.feat_drop(h)).view(-1, self._num_heads,
                                               self._att_dim)
        eh = (feat_head * self.attn_h).sum(dim=-1).unsqueeze(-1)
        et = (feat_tail * self.attn_t).sum(dim=-1).unsqueeze(-1)
        graph.ndata.update({'ft': feat, 'eh': eh, 'et': et})
        graph.apply_edges(fn.u_add_v('eh', 'et', 'e'))
        attentions = graph.edata.pop('e')
        attentions = self.leaky_relu(attentions)
        if drop_edge_ids is not None:
            attentions[drop_edge_ids] = self.attention_mask_value

        if self.top_k <= 0:
            graph.edata['a'] = edge_softmax(graph, attentions)
        else:
            if self.topk_type == 'local':
                graph.edata['e'] = attentions
                attentions = self.topk_attention(graph)
                graph.edata['a'] = edge_softmax(
                    graph, attentions)  # return attention scores
            else:
                graph.edata['e'] = edge_softmax(graph, attentions)
                graph.edata['a'] = self.topk_attention_softmax(graph)

        rst = self.ppr_estimation(graph=graph)
        rst = rst.flatten(1)
        rst = self.fc_out(rst)
        resval = self.res_fc(features)
        rst = resval + self.feat_drop(rst)

        rst_ff = self.feed_forward(self.ff_norm(rst))
        rst = rst + self.feat_drop(rst_ff)
        attentions = graph.edata.pop('a')
        return rst, attentions
Example #8
    def forward(self, graph: dgl.DGLGraph):
        if self.fourier_encodings > 0:
            graph.edata['d'] = fourier_encode_dist(
                graph.edata['d'], num_encodings=self.fourier_encodings)
        graph.apply_edges(self.input_edge_func)

        graph.update_all(message_func=self.message_function,
                         reduce_func=self.reduce_func(msg='m', out='m_sum'))

        if self.node_wise_output_layers > 0:
            graph.apply_nodes(self.output_node_func)

        readouts_to_cat = [
            dgl.readout_nodes(graph, 'feat', op=aggr)
            for aggr in self.readout_aggregators
        ]
        readout = torch.cat(readouts_to_cat, dim=-1)
        return self.output(readout)
Example #9
    def forward(self, graph: dgl.DGLGraph):
        graph.ndata['feat'] = self.node_embedding[None, :].expand(
            graph.number_of_nodes(), -1)

        if self.fourier_encodings > 0:
            graph.edata['d'] = fourier_encode_dist(
                graph.edata['d'], num_encodings=self.fourier_encodings)
        graph.apply_edges(self.input_edge_func)

        for mp_layer in self.mp_layers:
            mp_layer(graph)

        if self.node_wise_output_layers > 0:
            graph.apply_nodes(self.output_node_func)

        readouts_to_cat = [
            dgl.readout_nodes(graph, 'feat', op=aggr)
            for aggr in self.readout_aggregators
        ]
        readout = torch.cat(readouts_to_cat, dim=-1)
        return self.output(readout)
Example #10
class D2GCN(nn.Module):
    def __init__(self, in_feat_dim, out_feat_dim):
        super(D2GCN, self).__init__()
        self.fedge = nn.Sequential(
            nn.Linear(in_feat_dim * 2, in_feat_dim // 64),
            nn.BatchNorm1d(in_feat_dim // 64), nn.Dropout(dropout),
            nn.LeakyReLU(), nn.Linear(in_feat_dim // 64, out_feat_dim),
            nn.BatchNorm1d(out_feat_dim), nn.Dropout(dropout), nn.ReLU())
        if feature_drop:
            self.feat_drop = nn.Dropout(feature_drop)
        else:
            self.feat_drop = lambda x: x
        if att_drop:
            self.att_drop = nn.Dropout(att_drop)
        else:
            self.att_drop = lambda x: x
        self.attn_l = nn.Parameter(torch.Tensor(size=(1, out_feat_dim)))
        self.attn_r = nn.Parameter(torch.Tensor(size=(1, out_feat_dim)))

        self.relu = nn.LeakyReLU(alpha)
        self.softmax = edge_softmax
        self.fnode = nn.Sequential(
            nn.Linear(in_feat_dim + out_feat_dim, out_feat_dim // 64),
            nn.BatchNorm1d(out_feat_dim // 64), nn.Dropout(dropout),
            nn.LeakyReLU(), nn.Linear(out_feat_dim // 64, out_feat_dim),
            nn.BatchNorm1d(out_feat_dim), nn.Dropout(dropout), nn.ReLU())

        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)

    def build_graph(self, num_nodes, device):
        self.g = DGLGraph()
        self.g.add_nodes(num_nodes)
        # fully connect the graph in both directions, without self-loops
        for i in range(0, num_nodes):
            for j in range(i + 1, num_nodes):
                self.g.add_edge(i, j)
                self.g.add_edge(j, i)

        self.g = self.g.to(device)

    def send_source(self, edges):
        edge_feature = self.fedge(
            torch.cat((edges.src["h"], edges.dst["h"]), dim=1))
        msg = self.fnode(
            torch.cat((edges.src["h"], edge_feature), dim=1))
        m = torch.mul(msg, edges.data['a_drop'])
        return {"m": m}

    def simple_reduce(self, nodes):
        return {"h": torch.sum(nodes.mailbox['m'], dim=1) + nodes.data["h"]}

    def edge_attention(self, edges):
        a = self.relu(edges.src['a1'] + edges.dst['a2'])
        return {'a': a}

    def edge_softmax(self):
        att = self.softmax(self.g, self.g.edata.pop('a'))
        self.g.edata['a_drop'] = self.att_drop(att)

    def forward(self, n_feature):
        a1 = (n_feature * self.attn_l).sum(dim=-1).unsqueeze(-1)
        a2 = (n_feature * self.attn_r).sum(dim=-1).unsqueeze(-1)
        self.g.ndata.update({'h': n_feature, 'a1': a1, 'a2': a2})
        self.g.apply_edges(self.edge_attention)
        self.edge_softmax()
        # message passing over every edge, equivalent to send + recv
        self.g.update_all(self.send_source, self.simple_reduce)
        return self.g.ndata.pop('h')
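D2GCN in Example #10 reads dropout, feature_drop, att_drop and alpha as module-level hyperparameters, and the residual addition in simple_reduce requires in_feat_dim == out_feat_dim. A minimal, hypothetical driver under those assumptions (the values below are illustrative, not taken from the original project, and the hyperparameters are assumed to live in the same module as D2GCN):

    # hypothetical hyperparameters assumed by D2GCN.__init__
    dropout, feature_drop, att_drop, alpha = 0.1, 0.1, 0.1, 0.2

    layer = D2GCN(in_feat_dim=256, out_feat_dim=256)
    layer.build_graph(num_nodes=10, device=torch.device('cpu'))
    h = torch.randn(10, 256)
    out = layer(h)  # updated node features, shape (10, 256)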
Example #11
 def forward(self, dgl_data: dgl.DGLGraph):
     dgl_data.apply_edges(self.edge_init)
     return dgl_data
Example #12
 def forward(self, dgl_data: dgl.DGLGraph):
     # note: this call only updates the node features
     dgl_data.update_all(self.gcn_msg, self.gcn_reduce)
     dgl_data.apply_edges(self.edge_update)
     return dgl_data