Example #1
    def forward(self, feat, bg):
        # prepare, inputs are of shape V x F, V the number of nodes, F the dim of input features
        self.g = bg
        h = self.feat_drop(feat)
        # V x K x F', K number of heads, F' dim of transformed features
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))
        head_ft = ft.transpose(0, 1)  # K x V x F'
        a1 = th.bmm(head_ft, self.attn_l).transpose(0, 1)  # V x K x 1
        a2 = th.bmm(head_ft, self.attn_r).transpose(0, 1)  # V x K x 1
        self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
        # 1. compute edge attention
        self.g.apply_edges(self.edge_attention)
        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
        self.edge_softmax()
        # 3. compute the aggregated node features scaled by the dropped,
        # unnormalized attention values.
        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                          fn.sum('ft', 'ft'))
        # 4. merge the attention heads
        ret = self.g.ndata['ft']  # V x K x F'
        ret = ret.flatten(1)

        if self.agg_activation is not None:
            ret = self.agg_activation(ret)

        # Clean ndata and edata
        self.clean_data()

        return ret
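Several of the GAT snippets here (Examples 1-6, 8, 13, 23, 24) call `edge_attention` and `edge_softmax` helpers that are not shown. Below is a minimal sketch of what they look like in the old DGL GAT example, matching the field names above (`a1`, `a2`, `a`, `a_drop`, `z`); `self.leaky_relu` and `self.attn_drop` are assumed module attributes, and the exact bodies vary per project.

import torch as th
import dgl.function as fn

def edge_attention(self, edges):
    # unnormalized attention logit per edge: LeakyReLU(a1_src + a2_dst), shape E x K x 1
    a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
    return {'a': a}

def edge_softmax(self):
    # softmax over each node's incoming edges, in two parts for numerical
    # stability: exp(x - max(x)), then z = sum(exp(x - max(x)))
    self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
    self.g.apply_edges(
        lambda edges: {'a': th.exp(edges.data['a'] - edges.dst['a_max'])})
    # dropout on the unnormalized attention values
    self.g.edata['a_drop'] = self.attn_drop(self.g.edata['a'])
    # normalizer z, divided out in the "apply normalizer" step
    self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))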
Example #2
    def forward(self, inputs):
        # prepare
        h = self.feat_drop(inputs)  # NxD
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD
        #ft = self.mlp(ft).reshape((h.shape[0], self.num_heads, -1)) # NxHxD
        head_ft = ft.transpose(0, 1)  # HxNxD'
        a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
        a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
        self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
        # 1. compute edge attention
        self.g.apply_edges(self.edge_attention)
        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
        self.edge_softmax()
        # 3. compute the aggregated node features scaled by the dropped,
        # unnormalized attention values.
        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                          fn.sum('ft', 'ft'))
        # 4. apply normalizer
        ret = self.g.ndata['ft'] / self.g.ndata['z']  # NxHxD'
        ret = ret.reshape((h.shape[0], -1))
        ret = self.mlp(ret).reshape((h.shape[0], self.num_heads, -1))

        # 5. residual
        if self.residual:
            if self.res_fc is not None:
                resval = self.res_fc(h).reshape(
                    (h.shape[0], self.num_heads, -1))  # NxHxD'
            else:
                resval = torch.unsqueeze(h, 1)  # Nx1xD'
            ret = resval + ret
        return ret
Example #3
 def _pull_nodes(nodes):
     # compute ground truth
     g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
     o1 = g.ndata.pop('o1')
     g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
     o2 = g.ndata.pop('o2')
     g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
     o3 = g.ndata.pop('o3')
     # v2v spmv
     g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
            fn.sum(msg='m1', out='o1'), _afunc)
     assert F.allclose(o1, g.ndata.pop('o1'))
     # v2v fallback to e2v
     g.pull(nodes, fn.src_mul_edge(src='h', edge='w2', out='m2'),
            fn.sum(msg='m2', out='o2'), _afunc)
     assert F.allclose(o2, g.ndata.pop('o2'))
Example #4
    def forward(self, inputs):
        # prepare, inputs are of shape V x F, V the number of nodes, F the size of input features
        h = inputs
        if self.feat_drop:
            h = self.feat_drop(h)
        # V x K x F', K number of heads, F' size of transformed features
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))
        head_ft = ft.transpose(0, 1)  # K x V x F'
        a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # V x K x 1
        a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # V x K x 1
        if self.feat_drop:
            ft = self.feat_drop(ft)
        self.g.set_n_repr({'ft': ft, 'a1': a1, 'a2': a2})

        # 1. compute softmax without normalization for edge attention
        self.compute_edge_attention()
        # 2. compute two results: the node features scaled by the dropped,
        # unnormalized attention values, and the normalizer of the attention values.
        self.g.update_all(
            [fn.src_mul_edge('ft', 'a_drop', 'ft'),
             fn.copy_edge('a', 'a')],
            [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
        # 3. apply normalizer
        ret = self.g.ndata.pop('ft') / self.g.ndata['z']
        # 4. residual
        if self.residual:
            # Note that a broadcasting addition will be employed.
            if self.residual_fc:
                resval = self.residual_fc(h).reshape(
                    (h.shape[0], self.num_heads, -1))
            else:
                resval = h.unsqueeze(1)
            ret = resval + ret
        return ret
Example #5
 def forward(self, g, feature):
     # prepare
     h = self.feat_drop(feature)  # NxD
     ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
     a1 = (ft * self.attn_l).sum(dim=-1).unsqueeze(-1)  # N x H x 1
     a2 = (ft * self.attn_r).sum(dim=-1).unsqueeze(-1)  # N x H x 1
     g.ndata['ft'] = ft
     g.ndata['a1'] = a1
     g.ndata['a2'] = a2
     # 1. compute edge attention
     g.apply_edges(self.edge_attention)
     # 2. compute softmax
     self.edge_softmax(g)
     # 3. compute the aggregated node features scaled by the dropped,
     # unnormalized attention values.
     g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
     ret = g.ndata['ft']
     # 4. residual
     if self.residual:
         if self.res_fc is not None:
             resval = self.res_fc(h).reshape(
                 (h.shape[0], self.num_heads, -1))  # NxHxD'
         else:
             resval = torch.unsqueeze(h, 1)  # Nx1xD'
         ret = resval + ret
     return ret
Example #6
 def forward(self, g, inputs, last=False):
     # prepare
     h = self.feat_drop(inputs)  # NxD
     ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
     head_ft = ft.transpose(0, 1)  # HxNxD'
     a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
     a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
     g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
     # 1. compute edge attention
     g.apply_edges(self.edge_attention)
     # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
     self.edge_softmax(g)
     # 3. compute the aggregated node features scaled by the dropped,
     # unnormalized attention values.
     g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
     # 4. apply normalizer
     ret = g.ndata['ft'] / g.ndata['z']
     # 5. residual
     if self.residual:
         if self.res_fc is not None:
             resval = self.res_fc(h).reshape(
                 (h.shape[0], self.num_heads, -1))
         else:
             resval = torch.unsqueeze(h, 1)
         ret = ret + resval
     # 6. batch norm
     if not last:
         ret = self.batch_norm(ret.flatten(1))
     else:
         ret = ret.mean(1)
     return ret
Example #7
    def forward(self, inputs, topo):
        # prepare
        h, t = self.feat_drop(inputs), self.feat_drop(topo)  # NxD, N*T
        if not self.last_layer:
            ft = self.fl(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
            ft_c = torch.matmul(torch.cat((h, t), 1), self.fc).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
            ft_q = torch.matmul(h, self.fq).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
            self.g.ndata.update({'ft' : ft, 'ft_c' : ft_c, 'ft_q' : ft_q})
            self.g.apply_edges(self.edge_attention)
            self.edge_softmax()

            # prune attention: zero out the l_s smallest edge weights, then
            # scale the survivors up to redistribute the removed attention mass
            l_s = int(0.713 * self.g.edata['a_drop'].shape[0])
            topk, _ = torch.topk(self.g.edata['a_drop'], l_s, largest=False, dim=0)
            thd = torch.squeeze(topk[-1])
            self.g.edata['a_drop'] = self.g.edata['a_drop'].squeeze()
            self.g.edata['a_drop'] = torch.where(self.g.edata['a_drop'] - thd < 0,
                                                 self.g.edata['a_drop'].new([0.0]),
                                                 self.g.edata['a_drop'])
            attn_ratio = torch.div(self.g.edata['a_drop'].sum(0).squeeze() + topk.sum(0).squeeze(),
                                   self.g.edata['a_drop'].sum(0).squeeze())
            self.g.edata['a_drop'] = self.g.edata['a_drop'] * attn_ratio
            self.g.edata['a_drop'] = self.g.edata['a_drop'].unsqueeze(-1)
            
            self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
            ret = self.g.ndata['ft']
            if self.residual:
                if self.res_fl is not None:
                    resval = self.res_fl(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
                else:
                    resval = torch.unsqueeze(h, 1)  # Nx1xD'
                ret = resval + ret
            ret = torch.cat((ret.flatten(1), ft.mean(1).squeeze()), 1) if self.concat else ret.flatten(1)
        else:
            ret = self.fl(torch.cat((h, t), 1))
        return ret
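The pruning block in the middle of Example #7 is its distinctive part. A toy run of the thresholding step, with an illustrative tensor and a 0.5 keep-ratio (the snippet above uses 0.713):

import torch

a = torch.tensor([0.05, 0.30, 0.10, 0.55])
l_s = int(0.5 * a.shape[0])                    # drop the smallest half
topk, _ = torch.topk(a, l_s, largest=False)    # ascending: [0.05, 0.10]
thd = topk[-1]                                 # threshold = 0.10
pruned = torch.where(a - thd < 0, a.new_zeros(1), a)
# pruned -> [0.00, 0.30, 0.10, 0.55]; only values strictly below thd are zeroed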
Example #8
File: gat.py Project: zswzifir/dgl
 def forward(self, inputs):
     # prepare
     h = self.feat_drop(inputs)  # NxD
     ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
     a1 = (ft * self.attn_l.data(ft.context)).sum(axis=-1).expand_dims(
         -1)  # N x H x 1
     a2 = (ft * self.attn_r.data(ft.context)).sum(axis=-1).expand_dims(
         -1)  # N x H x 1
     self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
     # 1. compute edge attention
     self.g.apply_edges(self.edge_attention)
     # 2. compute softmax
     self.edge_softmax()
     # 3. compute the aggregated node features
     self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                       fn.sum('ft', 'ft'))
     ret = self.g.ndata['ft']
     # 4. residual
     if self.residual:
         if self.res_fc is not None:
             resval = self.res_fc(h).reshape(
                 (h.shape[0], self.num_heads, -1))  # NxHxD'
         else:
             resval = nd.expand_dims(h, axis=1)  # Nx1xD'
         ret = resval + ret
     return ret
Example #9
    def forward(self, g, node_state_prev):
        node_state = node_state_prev

        # if self.dropout:
        #     node_states = self.dropout(node_state)

        g = g.local_var()

        new_node_states = []

        # perform weighted convolution for every channel of the edge weight
        for c in range(self.num_channels):
            node_state_c = node_state
            if self._out_feats < self._in_feats:
                g.ndata['feat_' + str(c)] = torch.mm(node_state_c, self.weight[:, :, c])
            else:
                g.ndata['feat_' + str(c)] = node_state_c
            g.update_all(fn.src_mul_edge('feat_' + str(c), 'feat_' + str(c), 'm'), fn.sum('m', 'feat_' + str(c) + '_new'))
            node_state_c = g.ndata.pop('feat_' + str(c) + '_new')
            if self._out_feats >= self._in_feats:
                node_state_c = torch.mm(node_state_c, self.weight[:, :, c])          
            if self.bias is not None:
                node_state_c = node_state_c + self.bias[:, c]
            node_state_c = self.activation(node_state_c)   
            new_node_states.append(node_state_c) 
        if self.aggr_mode == 'sum':
            node_states = torch.stack(new_node_states, dim=1).sum(1)
        elif self.aggr_mode == 'concat':
            node_states = torch.cat(new_node_states, dim=1)

        node_states = self.final(node_states)

        return node_states
Example #10
    def forward(self, g, edge_logits, node_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        edge_logits : float32 tensor of shape (E, 1)
            The edge logits based on which softmax will be performed for weighting
            edges within 1-hop neighborhoods. E represents the number of edges.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Previous node features. V represents the number of nodes.

        Returns
        -------
        float32 tensor of shape (V, node_feat_size)
            Updated node features.
        """
        g = g.local_var()
        g.edata['a'] = edge_softmax(g, edge_logits)
        g.ndata['hv'] = self.project_node(node_feats)

        g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))
        context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))
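For context, here is a tiny check (not from the source; it uses the `edge_softmax` location in recent DGL releases) that edge weights are normalized within each node's incoming edges, as the docstring above describes:

import dgl
import torch
from dgl.nn.functional import edge_softmax  # location in recent DGL releases

g = dgl.graph(([0, 1, 2], [2, 2, 0]))   # two edges into node 2, one into node 0
logits = torch.tensor([[1.0], [2.0], [0.5]])
a = edge_softmax(g, logits)
# weights of edges sharing a destination sum to 1:
# a[0] + a[1] == 1 (both enter node 2), a[2] == 1 (only edge into node 0)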
Example #11
    def _test(fld):
        def message_func(edges):
            return {'m': edges.src[fld]}

        def message_func_edge(edges):
            if len(edges.src[fld].shape) == 1:
                return {'m': edges.src[fld] * edges.data['e1']}
            else:
                return {'m': edges.src[fld] * edges.data['e2']}

        def reduce_func(nodes):
            return {fld: F.sum(nodes.mailbox['m'], 1)}

        def apply_func(nodes):
            return {fld: 2 * nodes.data[fld]}

        g = generate_graph(idtype)
        # update all
        v1 = g.ndata[fld]
        g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld),
                     apply_func)
        v2 = g.ndata[fld]
        g.ndata.update({fld: v1})
        g.update_all(message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert F.allclose(v2, v3)
        # update all with edge weights
        v1 = g.ndata[fld]
        g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                     fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.ndata.update({fld: v1})
        g.update_all(message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert F.allclose(v2, v4)
Example #12
    def forward(self, batch_complete_graphs, node_feats, feat_sum,
                node_pair_feat):
        """Compute context vectors for each node.

        Parameters
        ----------
        batch_complete_graphs : DGLGraph
            A batch of fully connected graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        feat_sum : float32 tensor of shape (E_full, node_in_feats)
            Sum of node_feats between each pair of nodes. E_full for the number of
            edges in the batch of complete graphs.
        node_pair_feat : float32 tensor of shape (E_full, node_pair_in_feats)
            Input features for each pair of nodes. E_full for the number of edges in
            the batch of complete graphs.

        Returns
        -------
        node_contexts : float32 tensor of shape (V, node_in_feats)
            Context vectors for nodes.
        """
        with batch_complete_graphs.local_scope():
            batch_complete_graphs.ndata['hv'] = node_feats
            batch_complete_graphs.edata['a'] = self.compute_attention(
                self.project_feature_sum(feat_sum) + \
                self.project_node_pair_feature(node_pair_feat)
            )
            batch_complete_graphs.update_all(fn.src_mul_edge('hv', 'a', 'm'),
                                             fn.sum('m', 'context'))
            node_contexts = batch_complete_graphs.ndata.pop('context')

        return node_contexts
Example #13
File: train.py Project: zcrwind/dgl
 def forward(self, inputs):
     # prepare
     h = inputs
     if self.feat_drop:
         h = self.feat_drop(h)
     ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))
     head_ft = ft.transpose(0, 1)
     a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)
     a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)
     if self.feat_drop:
         ft = self.feat_drop(ft)
     self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
     # 1. compute edge attention
     self.g.apply_edges(self.edge_attention)
     # 2. compute two results: the node features scaled by the dropped,
     # unnormalized attention values, and the normalizer of the attention values.
     self.g.update_all(
         [fn.src_mul_edge('ft', 'a_drop', 'ft'),
          fn.copy_edge('a', 'a')],
         [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
     # 3. apply normalizer
     ret = self.g.ndata['ft'] / self.g.ndata['z']
     # 4. residual
     if self.residual:
         if self.residual_fc:
             ret = self.residual_fc(h) + ret
         else:
             ret = h + ret
     return ret
Example #14
def test_v2v_snr_multi_fn():
    u = th.tensor([0, 0, 0, 3, 4, 9])
    v = th.tensor([1, 2, 3, 9, 9, 0])

    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        return {'v1' : th.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph()
    g.set_n_repr({'v1' : th.zeros((10, D)), 'v2' : th.zeros((10, D)),
        'v3' : th.zeros((10, D))})
    fld = 'f2'

    g.send_and_recv((u, v), message_func, reduce_func)
    v1 = g.ndata['v1']

    # 1 message, 2 reduces
    g.send_and_recv((u, v),
            fn.copy_src(src=fld, out='m'),
            [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')],
            None)
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert U.allclose(v1, v2)
    assert U.allclose(v1, v3)

    # send and recv with edge weights, 2 messages, 3 reduces
    g.send_and_recv((u, v),
                    [fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                    [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'), fn.sum(msg='m1', out='v3')],
                    None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert U.allclose(v1, v2)
    assert U.allclose(v1, v3)

    # run UDF with single message and reduce
    g.send_and_recv((u, v), message_func_edge,
            reduce_func, None)
    v2 = g.ndata['v2']
    assert U.allclose(v1, v2)
Example #15
 def propagate_attention(self, g, eids):
     # Compute attention score
     g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
     g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
     # Send weighted values to target nodes
     g.send_and_recv(eids,
                     [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                     [fn.sum('v', 'wv'), fn.sum('score', 'z')])
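Examples 15, 21, 22, 28, and 29 rely on `src_dot_dst`, `scaled_exp`, and (in Example 21) `div_by_z`, which are not shown. A sketch of these helpers as they appear in DGL's transformer example; the clamping bounds and the per-head signature variations (e.g. Example #29) differ per project:

import torch as th

def src_dot_dst(src_field, dst_field, out_field):
    # per-head dot product between source and destination features
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
    return func

def scaled_exp(field, scale_constant):
    # exp of the scaled score; clamping guards against overflow
    def func(edges):
        return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}
    return func

def div_by_z(wv_field, z_field, out_field):
    # normalize the weighted value sum by the attention normalizer z
    def func(nodes):
        return {out_field: nodes.data[wv_field] / nodes.data[z_field]}
    return func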
Example #16
def test_src_mul_edge():
    # src_mul_edge with all fields
    g = generate_graph()
    g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
    g.register_reduce_func(reducer_both)
    g.update_all()
    assert U.allclose(g.ndata['h'],
                      th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
Example #17
def test_v2v_update_all_multi_fn(idtype):
    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        return {'v1': F.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph(idtype)
    g.ndata.update({'v1': F.zeros((10, )), 'v2': F.zeros((10, ))})
    fld = 'f2'

    g.update_all(message_func, reduce_func)
    v1 = g.ndata['v1']

    # 1 message, 2 reduces
    g.update_all(fn.copy_src(src=fld, out='m'),
                 [fn.sum(msg='m', out='v2'),
                  fn.sum(msg='m', out='v3')])
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # update all with edge weights, 2 messages, 3 reduces
    g.update_all([
        fn.src_mul_edge(src=fld, edge='e1', out='m1'),
        fn.src_mul_edge(src=fld, edge='e2', out='m2')
    ], [
        fn.sum(msg='m1', out='v1'),
        fn.sum(msg='m2', out='v2'),
        fn.sum(msg='m1', out='v3')
    ], None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # run UDF with single message and reduce
    g.update_all(message_func_edge, reduce_func, None)
    v2 = g.ndata['v2']
    assert F.allclose(v1, v2)
Example #18
    def forward(self, g):

        g.update_all(fn.src_mul_edge('node_feats', 'edge_feats', 'msg'),
                     fn.sum('msg', 'reduced'))
        g.ndata['node_feats'] = self.linear(
            torch.cat((g.ndata['node_feats'], g.ndata['reduced']), dim=-1))
        if self.activation is not None:
            g.ndata['node_feats'] = self.activation(g.ndata['node_feats'])
        return g
Example #19
 def forward(self, feat):
     g = self.graph.local_var()
     g.ndata['h'] = feat.mm(self.W)
     g.update_all(fn.src_mul_edge(src='h', edge='w', out='m'),
                  fn.sum(msg='m', out='h'))
     rst = g.ndata['h']
     #rst = self.linear(rst)
     rst = self.activation(rst)
     return rst
Example #20
    def __init__(self, in_feats, out_feats, last=False):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.last = last
        """multiply src with edge data or not"""
        # self.msg_func = fn.copy_src(src='h', out='m')
        self.msg_func = fn.src_mul_edge(src='h', edge='w', out='m')

        self.reduce_func = fn.sum(msg='m', out='h')
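Only the constructor is shown above. A minimal sketch of a matching `forward`, assuming `torch` is imported, node features are passed in, and edge weights live in `g.edata['w']`; the ReLU on non-last layers is an assumption, not from the source:

    def forward(self, g, feature):
        g = g.local_var()
        g.ndata['h'] = feature
        # m_uv = h_u * w_uv on every edge, then h_v = sum of incoming messages
        g.update_all(self.msg_func, self.reduce_func)
        h = self.linear(g.ndata.pop('h'))
        return h if self.last else torch.relu(h)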
Example #21
 def propagate_attention(self, g):
     # Compute attention score
     g.apply_edges(src_dot_dst('k', 'q', 'score'))
     g.apply_edges(scaled_exp('score', math.sqrt(self.d_k)))
     # Update node state
     g.update_all(fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv'))
     g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'),
                  div_by_z('wv', 'z', 'o'))
     out_x = g.nodes['schema'].data['o']
     return out_x
Example #22
    def propagate_attention(self, g):
        # Compute attention score
        g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'))
        g.apply_edges(scaled_exp('score', np.sqrt(self.out_dim)))

        # Send weighted values to target nodes
        eids = g.edges()
        g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score', 'V_h'),
                        fn.sum('V_h', 'wV'))
        g.send_and_recv(eids, fn.copy_edge('score', 'score'),
                        fn.sum('score', 'z'))
Example #23
 def forward(self, x):
     x = torch.matmul(x, self.weight)
     x = x.reshape((x.size(0), self.heads, -1))  # NxHxD'
     head_x = x.transpose(0, 1)  # HxNxD'
     a1 = torch.bmm(head_x, self.att_l).transpose(0, 1)  # NxHx1
     a2 = torch.bmm(head_x, self.att_r).transpose(0, 1)  # NxHx1
     self.g.ndata.update({'x': x, 'a1': a1, 'a2': a2})
     self.g.apply_edges(self.edge_attention)
     self.edge_softmax()
     self.g.update_all(fn.src_mul_edge('x', 'a', 'x'), fn.sum('x', 'x'))
     x = self.g.ndata['x'] / self.g.ndata['z']  # NxHxD'
     return x.view(-1, self.heads * self.out_channels)
Example #24
 def forward(self, x):
     ft = self.fc(x).reshape((x.shape[0], self.heads, -1))  # NxHxD'
     head_ft = ft.transpose(0, 1)  # HxNxD'
     a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
     a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
     self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
     self.g.apply_edges(self.edge_attention)
     self.edge_softmax()
     self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                       fn.sum('ft', 'ft'))
     ret = self.g.ndata['ft'] / self.g.ndata['z']  # NxHxD'
     return ret.view(-1, self.heads * self.out_channels)
Example #25
    def _test(fld):
        def message_func(edges):
            return {'m': edges.src[fld]}

        def message_func_edge(edges):
            if len(edges.src[fld].shape) == 1:
                return {'m': edges.src[fld] * edges.data['e1']}
            else:
                return {'m': edges.src[fld] * edges.data['e2']}

        def reduce_func(nodes):
            return {fld: F.sum(nodes.mailbox['m'], 1)}

        def apply_func(nodes):
            return {fld: 2 * nodes.data[fld]}

        g = generate_graph()
        # send and recv
        v1 = g.ndata[fld]
        g.send_and_recv((u, v), fn.copy_src(src=fld, out='m'),
                        fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.set_n_repr({fld: v1})
        g.send_and_recv((u, v), message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert F.allclose(v2, v3)
        # send and recv with edge weights
        v1 = g.ndata[fld]
        g.send_and_recv((u, v), fn.src_mul_edge(src=fld, edge='e1', out='m'),
                        fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.set_n_repr({fld: v1})
        g.send_and_recv((u, v), fn.src_mul_edge(src=fld, edge='e2', out='m'),
                        fn.sum(msg='m', out=fld), apply_func)
        v3 = g.ndata[fld]
        g.set_n_repr({fld: v1})
        g.send_and_recv((u, v), message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert F.allclose(v2, v3)
        assert F.allclose(v3, v4)
Example #26
    def _test(fld):
        def message_func(edges):
            return {'m': edges.src[fld]}

        def message_func_edge(edges):
            if len(edges.src[fld].shape) == 1:
                return {'m': edges.src[fld] * edges.data['e1']}
            else:
                return {'m': edges.src[fld] * edges.data['e2']}

        def reduce_func(nodes):
            return {fld: mx.nd.max(nodes.mailbox['m'], axis=1)}

        def apply_func(nodes):
            return {fld: 2 * nodes.data[fld]}

        g = simple_graph()
        # update all
        v1 = g.ndata[fld]
        g.update_all(fn.copy_src(src=fld, out='m'), fn.max(msg='m', out=fld),
                     apply_func)
        v2 = g.ndata[fld]
        g.set_n_repr({fld: v1})
        g.update_all(message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert np.allclose(v2.asnumpy(), v3.asnumpy(), rtol=1e-05, atol=1e-05)
        # update all with edge weights
        v1 = g.ndata[fld]
        g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                     fn.max(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.set_n_repr({fld: v1})
        g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
                     fn.max(msg='m', out=fld), apply_func)
        v3 = g.ndata[fld].squeeze()
        g.set_n_repr({fld: v1})
        g.update_all(message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert np.allclose(v2.asnumpy(), v3.asnumpy(), rtol=1e-05, atol=1e-05)
        assert np.allclose(v3.asnumpy(), v4.asnumpy(), rtol=1e-05, atol=1e-05)
Example #27
 def forward(self, g):
     alpha_prime = self.leaky_relu(self.attn(g.edata[self.attn_key]))
     # The key step: multiply the attention weights with the edge embedding
     g.edata['a'] = dglnn.edge_softmax(
         g, alpha_prime) * g.edata['emb'].view(g.edata['emb'].shape[0],
                                               self.n_heads, -1)
     attn_emb = g.ndata[self.msg_key]
     if attn_emb.ndimension() == 2:
         g.ndata[self.msg_key] = attn_emb.view(g.number_of_nodes(),
                                               self.n_heads, -1)
     g.update_all(fn.src_mul_edge(self.msg_key, 'a', 'm'),
                  fn.sum('m', 'emb'))
     return GraphLambda(lambda x: x.view(x.shape[0], -1))(g)
Example #28
 def propagate_attention(self, g, eids):
     # Compute attention score
     g.apply_edges(src_dot_dst("k", "q", "score"), eids)
     g.apply_edges(scaled_exp("score", np.sqrt(self.d_k)), eids)
     # Send weighted values to target nodes
     g.send_and_recv(
         eids,
         [
             fn.src_mul_edge("v", "score", "v"),
             fn.copy_edge("score", "score")
         ],
         [fn.sum("v", "wv"), fn.sum("score", "z")],
     )
Example #29
 def propagate_attention(self, g, eids, per_head=False):
     # Compute attention score
     if per_head:
         for i in range(len(per_head)):
             # This sends in the edges per head.
             score_key = 'score{}'.format(i)
             g.apply_edges(src_dot_dst('k', 'q', score_key, i), per_head[i])
             g.apply_edges(scaled_exp(score_key, np.sqrt(self.d_k)),
                           per_head[i])
             # Send weighted values to target nodes
             g.send_and_recv(per_head[i], [
                 fn.src_mul_edge('v', score_key, 'v'),
                 fn.copy_edge(score_key, score_key)
             ], [fn.sum('v', 'wv'),
                 fn.sum(score_key, 'z')])
     else:
         g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
         g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
         # Send weighted values to target nodes
         g.send_and_recv(eids, [
             fn.src_mul_edge('v', 'score', 'v'),
             fn.copy_edge('score', 'score')
         ], [fn.sum('v', 'wv'), fn.sum('score', 'z')])
Example #30
def test_src_mul_edge():
    # src_mul_edge with all fields
    g = generate_graph()
    g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
    g.register_reduce_func(reducer_both)
    # test with update_all
    g.update_all()
    assert F.allclose(g.ndata.pop('out'),
                      F.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
    # test with send and then recv
    g.send()
    g.recv()
    assert F.allclose(g.ndata.pop('out'),
                      F.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
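A closing note: in DGL 0.4 and later these built-ins were renamed, with `fn.src_mul_edge` generalized to `fn.u_mul_e`, `fn.copy_src` to `fn.copy_u`, and `fn.copy_edge` to `fn.copy_e`. Example #1's aggregation step in the newer API would read:

import dgl.function as fn

# old: g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
g.update_all(fn.u_mul_e('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))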