def forward(self, feat, bg):
    # prepare; inputs are of shape V x F, V the number of nodes, F the dim of input features
    self.g = bg
    h = self.feat_drop(feat)
    # V x K x F', K the number of heads, F' the dim of transformed features
    ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))
    head_ft = ft.transpose(0, 1)                          # K x V x F'
    a1 = th.bmm(head_ft, self.attn_l).transpose(0, 1)     # V x K x 1
    a2 = th.bmm(head_ft, self.attn_r).transpose(0, 1)     # V x K x 1
    self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
    # 1. compute edge attention
    self.g.apply_edges(self.edge_attention)
    # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
    self.edge_softmax()
    # 3. compute the aggregated node features scaled by the dropped
    # attention values
    self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                      fn.sum('ft', 'ft'))
    # 4. read out the aggregated node features
    ret = self.g.ndata['ft']                              # V x K x F'
    ret = ret.flatten(1)
    if self.agg_activation is not None:
        ret = self.agg_activation(ret)
    # clean ndata and edata
    self.clean_data()
    return ret
def forward(self, inputs):
    # prepare
    h = self.feat_drop(inputs)                                    # N x D
    ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))     # N x H x D'
    # ft = self.mlp(ft).reshape((h.shape[0], self.num_heads, -1))  # N x H x D'
    head_ft = ft.transpose(0, 1)                                  # H x N x D'
    a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)          # N x H x 1
    a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)          # N x H x 1
    self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
    # 1. compute edge attention
    self.g.apply_edges(self.edge_attention)
    # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
    self.edge_softmax()
    # 3. compute the aggregated node features scaled by the dropped,
    # unnormalized attention values
    self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                      fn.sum('ft', 'ft'))
    # 4. apply normalizer
    ret = self.g.ndata['ft'] / self.g.ndata['z']                  # N x H x D'
    ret = ret.reshape((h.shape[0], -1))
    ret = self.mlp(ret).reshape((h.shape[0], self.num_heads, -1))
    # 5. residual
    if self.residual:
        if self.res_fc is not None:
            resval = self.res_fc(h).reshape(
                (h.shape[0], self.num_heads, -1))                 # N x H x D'
        else:
            resval = torch.unsqueeze(h, 1)                        # N x 1 x D'
        ret = resval + ret
    return ret
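# NOTE: several of the layers above call a `self.edge_softmax()` helper that
# is not included in these snippets. A minimal sketch of the two-pass trick
# the comments describe (exp(x - max(x)) for numerical stability, then the
# sum that becomes the normalizer 'z'), assuming the unnormalized scores live
# in edata['a'] and that an attention-dropout module `self.attn_drop` exists:
def edge_softmax(self):
    # first pass: per-destination max of the raw scores
    self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
    # exp(x - max(x)) on every edge
    self.g.apply_edges(
        lambda edges: {'a': torch.exp(edges.data['a'] - edges.dst['a_max'])})
    # dropped attention values used to scale the messages
    self.g.apply_edges(
        lambda edges: {'a_drop': self.attn_drop(edges.data['a'])})
    # second pass: normalizer z = sum(exp(x - max(x))) per destination node
    self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))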
def _pull_nodes(nodes):
    # compute ground truth
    g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
    o1 = g.ndata.pop('o1')
    g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
    o2 = g.ndata.pop('o2')
    g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
    o3 = g.ndata.pop('o3')
    # v2v spmv
    g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
           fn.sum(msg='m1', out='o1'), _afunc)
    assert F.allclose(o1, g.ndata.pop('o1'))
    # v2v fallback to e2v
    g.pull(nodes, fn.src_mul_edge(src='h', edge='w2', out='m2'),
           fn.sum(msg='m2', out='o2'), _afunc)
    assert F.allclose(o2, g.ndata.pop('o2'))
def forward(self, inputs):
    # prepare; inputs are of shape V x F, V the number of nodes, F the size of input features
    h = inputs
    if self.feat_drop:
        h = self.feat_drop(h)
    # V x K x F', K the number of heads, F' the size of transformed features
    ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))
    head_ft = ft.transpose(0, 1)                              # K x V x F'
    a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)      # V x K x 1
    a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)      # V x K x 1
    if self.feat_drop:
        ft = self.feat_drop(ft)
    self.g.set_n_repr({'ft': ft, 'a1': a1, 'a2': a2})
    # 1. compute softmax without normalization for edge attention
    self.compute_edge_attention()
    # 2. compute two results: the node features scaled by the dropped,
    #    unnormalized attention values, and the normalizer of the attention values
    self.g.update_all(
        [fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.copy_edge('a', 'a')],
        [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
    # 3. apply normalizer
    ret = self.g.ndata.pop('ft') / self.g.ndata['z']
    # 4. residual
    if self.residual:
        # note that a broadcasting addition is employed below
        if self.residual_fc:
            resval = self.residual_fc(h).reshape(
                (h.shape[0], self.num_heads, -1))
        else:
            resval = h.unsqueeze(1)
        ret = resval + ret
    return ret
def forward(self, g, feature):
    # prepare
    h = self.feat_drop(feature)                                   # N x D
    ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))     # N x H x D'
    a1 = (ft * self.attn_l).sum(dim=-1).unsqueeze(-1)             # N x H x 1
    a2 = (ft * self.attn_r).sum(dim=-1).unsqueeze(-1)             # N x H x 1
    g.ndata['ft'] = ft
    g.ndata['a1'] = a1
    g.ndata['a2'] = a2
    # 1. compute edge attention
    g.apply_edges(self.edge_attention)
    # 2. compute softmax
    self.edge_softmax(g)
    # 3. compute the aggregated node features scaled by the dropped,
    # unnormalized attention values
    g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                 fn.sum('ft', 'ft'))
    ret = g.ndata['ft']
    # 4. residual
    if self.residual:
        if self.res_fc is not None:
            resval = self.res_fc(h).reshape(
                (h.shape[0], self.num_heads, -1))                 # N x H x D'
        else:
            resval = torch.unsqueeze(h, 1)                        # N x 1 x D'
        ret = resval + ret
    return ret
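# NOTE: the `self.edge_attention` UDF that these GAT-style layers register is
# not shown in the snippets. In the DGL GAT example it combines the two score
# halves with a LeakyReLU; a sketch under that assumption (the
# `self.leaky_relu` module is assumed, not taken from the snippets above):
def edge_attention(self, edges):
    # unnormalized attention score a_ij = LeakyReLU(a1_i + a2_j), E x H x 1
    a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
    return {'a': a}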
def forward(self, g, inputs, last=False):
    # prepare
    h = self.feat_drop(inputs)                                    # N x D
    ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))     # N x H x D'
    head_ft = ft.transpose(0, 1)                                  # H x N x D'
    a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)          # N x H x 1
    a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)          # N x H x 1
    g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
    # 1. compute edge attention
    g.apply_edges(self.edge_attention)
    # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
    self.edge_softmax(g)
    # 3. compute the aggregated node features scaled by the dropped,
    # unnormalized attention values
    g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                 fn.sum('ft', 'ft'))
    # 4. apply normalizer
    ret = g.ndata['ft'] / g.ndata['z']
    # 5. residual
    if self.residual:
        if self.res_fc is not None:
            resval = self.res_fc(h).reshape(
                (h.shape[0], self.num_heads, -1))
        else:
            resval = torch.unsqueeze(h, 1)
        ret = ret + resval
    # 6. batch norm on all but the last layer
    if not last:
        ret = self.batch_norm(ret.flatten(1))
    else:
        ret = ret.mean(1)
    return ret
def forward(self, inputs, topo):
    # prepare
    h, t = self.feat_drop(inputs), self.feat_drop(topo)           # N x D, N x T
    if not self.last_layer:
        ft = self.fl(h).reshape((h.shape[0], self.num_heads, -1))  # N x H x D'
        ft_c = torch.matmul(torch.cat((h, t), 1), self.fc).reshape(
            (h.shape[0], self.num_heads, -1))                     # N x H x D'
        ft_q = torch.matmul(h, self.fq).reshape(
            (h.shape[0], self.num_heads, -1))                     # N x H x D'
        self.g.ndata.update({'ft': ft, 'ft_c': ft_c, 'ft_q': ft_q})
        self.g.apply_edges(self.edge_attention)
        self.edge_softmax()
        # zero out the smallest ~71.3% of the attention values, then rescale
        # the surviving edges so the total attention mass is preserved
        l_s = int(0.713 * self.g.edata['a_drop'].shape[0])
        topk, _ = torch.topk(self.g.edata['a_drop'], l_s, largest=False, dim=0)
        thd = torch.squeeze(topk[-1])
        self.g.edata['a_drop'] = self.g.edata['a_drop'].squeeze()
        self.g.edata['a_drop'] = torch.where(
            self.g.edata['a_drop'] - thd < 0,
            self.g.edata['a_drop'].new([0.0]),
            self.g.edata['a_drop'])
        attn_ratio = torch.div(
            self.g.edata['a_drop'].sum(0).squeeze() + topk.sum(0).squeeze(),
            self.g.edata['a_drop'].sum(0).squeeze())
        self.g.edata['a_drop'] = self.g.edata['a_drop'] * attn_ratio
        self.g.edata['a_drop'] = self.g.edata['a_drop'].unsqueeze(-1)
        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                          fn.sum('ft', 'ft'))
        ret = self.g.ndata['ft']
        if self.residual:
            if self.res_fl is not None:
                resval = self.res_fl(h).reshape(
                    (h.shape[0], self.num_heads, -1))             # N x H x D'
            else:
                resval = torch.unsqueeze(h, 1)                    # N x 1 x D'
            ret = resval + ret
        ret = (torch.cat((ret.flatten(1), ft.mean(1).squeeze()), 1)
               if self.concat else ret.flatten(1))
    else:
        ret = self.fl(torch.cat((h, t), 1))
    return ret
def forward(self, inputs):
    # prepare
    h = self.feat_drop(inputs)                                    # N x D
    ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))     # N x H x D'
    a1 = (ft * self.attn_l.data(ft.context)).sum(
        axis=-1).expand_dims(-1)                                  # N x H x 1
    a2 = (ft * self.attn_r.data(ft.context)).sum(
        axis=-1).expand_dims(-1)                                  # N x H x 1
    self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
    # 1. compute edge attention
    self.g.apply_edges(self.edge_attention)
    # 2. compute softmax
    self.edge_softmax()
    # 3. compute the aggregated node features
    self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                      fn.sum('ft', 'ft'))
    ret = self.g.ndata['ft']
    # 4. residual
    if self.residual:
        if self.res_fc is not None:
            resval = self.res_fc(h).reshape(
                (h.shape[0], self.num_heads, -1))                 # N x H x D'
        else:
            resval = nd.expand_dims(h, axis=1)                    # N x 1 x D'
        ret = resval + ret
    return ret
def forward(self, g, node_state_prev):
    node_state = node_state_prev
    # if self.dropout:
    #     node_states = self.dropout(node_state)
    g = g.local_var()
    new_node_states = []
    # perform weighted convolution for every channel of edge weight
    for c in range(self.num_channels):
        node_state_c = node_state
        # project before aggregation when it shrinks the feature size
        if self._out_feats < self._in_feats:
            g.ndata['feat_' + str(c)] = torch.mm(node_state_c,
                                                 self.weight[:, :, c])
        else:
            g.ndata['feat_' + str(c)] = node_state_c
        g.update_all(fn.src_mul_edge('feat_' + str(c), 'feat_' + str(c), 'm'),
                     fn.sum('m', 'feat_' + str(c) + '_new'))
        node_state_c = g.ndata.pop('feat_' + str(c) + '_new')
        if self._out_feats >= self._in_feats:
            node_state_c = torch.mm(node_state_c, self.weight[:, :, c])
        if self.bias is not None:
            node_state_c = node_state_c + self.bias[:, c]
        node_state_c = self.activation(node_state_c)
        new_node_states.append(node_state_c)
    if self.aggr_mode == 'sum':
        node_states = torch.stack(new_node_states, dim=1).sum(1)
    elif self.aggr_mode == 'concat':
        node_states = torch.cat(new_node_states, dim=1)
    node_states = self.final(node_states)
    return node_states
def forward(self, g, edge_logits, node_feats):
    """Update node representations.

    Parameters
    ----------
    g : DGLGraph
        DGLGraph for a batch of graphs
    edge_logits : float32 tensor of shape (E, 1)
        The edge logits based on which softmax will be performed for
        weighting edges within 1-hop neighborhoods. E represents the
        number of edges.
    node_feats : float32 tensor of shape (V, node_feat_size)
        Previous node features. V represents the number of nodes.

    Returns
    -------
    float32 tensor of shape (V, node_feat_size)
        Updated node features.
    """
    g = g.local_var()
    g.edata['a'] = edge_softmax(g, edge_logits)
    g.ndata['hv'] = self.project_node(node_feats)
    g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))
    context = F.elu(g.ndata['c'])
    return F.relu(self.gru(context, node_feats))
def _test(fld):
    def message_func(edges):
        return {'m': edges.src[fld]}

    def message_func_edge(edges):
        if len(edges.src[fld].shape) == 1:
            return {'m': edges.src[fld] * edges.data['e1']}
        else:
            return {'m': edges.src[fld] * edges.data['e2']}

    def reduce_func(nodes):
        return {fld: F.sum(nodes.mailbox['m'], 1)}

    def apply_func(nodes):
        return {fld: 2 * nodes.data[fld]}

    g = generate_graph(idtype)
    # update all
    v1 = g.ndata[fld]
    g.update_all(fn.copy_src(src=fld, out='m'),
                 fn.sum(msg='m', out=fld), apply_func)
    v2 = g.ndata[fld]
    g.ndata.update({fld: v1})
    g.update_all(message_func, reduce_func, apply_func)
    v3 = g.ndata[fld]
    assert F.allclose(v2, v3)
    # update all with edge weights
    v1 = g.ndata[fld]
    g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                 fn.sum(msg='m', out=fld), apply_func)
    v2 = g.ndata[fld]
    g.ndata.update({fld: v1})
    g.update_all(message_func_edge, reduce_func, apply_func)
    v4 = g.ndata[fld]
    assert F.allclose(v2, v4)
def forward(self, batch_complete_graphs, node_feats, feat_sum, node_pair_feat):
    """Compute context vectors for each node.

    Parameters
    ----------
    batch_complete_graphs : DGLGraph
        A batch of fully connected graphs.
    node_feats : float32 tensor of shape (V, node_in_feats)
        Input node features. V for the number of nodes.
    feat_sum : float32 tensor of shape (E_full, node_in_feats)
        Sum of node_feats between each pair of nodes. E_full for the
        number of edges in the batch of complete graphs.
    node_pair_feat : float32 tensor of shape (E_full, node_pair_in_feats)
        Input features for each pair of nodes. E_full for the number of
        edges in the batch of complete graphs.

    Returns
    -------
    node_contexts : float32 tensor of shape (V, node_in_feats)
        Context vectors for nodes.
    """
    with batch_complete_graphs.local_scope():
        batch_complete_graphs.ndata['hv'] = node_feats
        batch_complete_graphs.edata['a'] = self.compute_attention(
            self.project_feature_sum(feat_sum) +
            self.project_node_pair_feature(node_pair_feat))
        batch_complete_graphs.update_all(fn.src_mul_edge('hv', 'a', 'm'),
                                         fn.sum('m', 'context'))
        node_contexts = batch_complete_graphs.ndata.pop('context')
    return node_contexts
def forward(self, inputs):
    # prepare
    h = inputs
    if self.feat_drop:
        h = self.feat_drop(h)
    ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))
    head_ft = ft.transpose(0, 1)
    a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)
    a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)
    if self.feat_drop:
        ft = self.feat_drop(ft)
    self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
    # 1. compute edge attention
    self.g.apply_edges(self.edge_attention)
    # 2. compute two results: the node features scaled by the dropped,
    #    unnormalized attention values, and the normalizer of the attention values
    self.g.update_all(
        [fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.copy_edge('a', 'a')],
        [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
    # 3. apply normalizer
    ret = self.g.ndata['ft'] / self.g.ndata['z']
    # 4. residual
    if self.residual:
        if self.residual_fc:
            ret = self.residual_fc(h) + ret
        else:
            ret = h + ret
    return ret
def test_v2v_snr_multi_fn():
    u = th.tensor([0, 0, 0, 3, 4, 9])
    v = th.tensor([1, 2, 3, 9, 9, 0])

    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        return {'v1': th.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph()
    g.set_n_repr({'v1': th.zeros((10, D)),
                  'v2': th.zeros((10, D)),
                  'v3': th.zeros((10, D))})
    fld = 'f2'
    g.send_and_recv((u, v), message_func, reduce_func)
    v1 = g.ndata['v1']
    # 1 message, 2 reduces
    g.send_and_recv((u, v), fn.copy_src(src=fld, out='m'),
                    [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')],
                    None)
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert U.allclose(v1, v2)
    assert U.allclose(v1, v3)
    # send and recv with edge weights, 2 messages, 3 reduces
    g.send_and_recv((u, v),
                    [fn.src_mul_edge(src=fld, edge='e1', out='m1'),
                     fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                    [fn.sum(msg='m1', out='v1'),
                     fn.sum(msg='m2', out='v2'),
                     fn.sum(msg='m1', out='v3')],
                    None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert U.allclose(v1, v2)
    assert U.allclose(v1, v3)
    # run UDF with single message and reduce
    g.send_and_recv((u, v), message_func_edge, reduce_func, None)
    v2 = g.ndata['v2']
    assert U.allclose(v1, v2)
def propagate_attention(self, g, eids):
    # Compute attention score
    g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
    g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
    # Send weighted values to target nodes
    g.send_and_recv(eids,
                    [fn.src_mul_edge('v', 'score', 'v'),
                     fn.copy_edge('score', 'score')],
                    [fn.sum('v', 'wv'), fn.sum('score', 'z')])
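# NOTE: `src_dot_dst` and `scaled_exp` are edge UDFs that the
# transformer-style snippets assume but do not define. A sketch matching how
# the DGL transformer example writes them (the clamp bounds are one common
# choice, not taken from the snippets above):
def src_dot_dst(src_field, dst_field, out_field):
    # per-head dot product between source and destination features
    def func(edges):
        return {out_field: (edges.src[src_field] *
                            edges.dst[dst_field]).sum(-1, keepdim=True)}
    return func

def scaled_exp(field, scale_constant):
    # divide by sqrt(d_k) and exponentiate; clamping guards against overflow
    def func(edges):
        return {field: torch.exp((edges.data[field] /
                                  scale_constant).clamp(-10, 10))}
    return func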
def test_src_mul_edge():
    # src_mul_edge with all fields
    g = generate_graph()
    g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
    g.register_reduce_func(reducer_both)
    g.update_all()
    assert U.allclose(g.ndata['h'],
                      th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
def test_v2v_update_all_multi_fn(idtype):
    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        return {'v1': F.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph(idtype)
    g.ndata.update({'v1': F.zeros((10,)), 'v2': F.zeros((10,))})
    fld = 'f2'
    g.update_all(message_func, reduce_func)
    v1 = g.ndata['v1']
    # 1 message, 2 reduces
    g.update_all(fn.copy_src(src=fld, out='m'),
                 [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')])
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)
    # update all with edge weights, 2 messages, 3 reduces
    g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'),
                  fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                 [fn.sum(msg='m1', out='v1'),
                  fn.sum(msg='m2', out='v2'),
                  fn.sum(msg='m1', out='v3')],
                 None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)
    # run UDF with single message and reduce
    g.update_all(message_func_edge, reduce_func, None)
    v2 = g.ndata['v2']
    assert F.allclose(v1, v2)
def forward(self, g):
    g.update_all(fn.src_mul_edge('node_feats', 'edge_feats', 'msg'),
                 fn.sum('msg', 'reduced'))
    g.ndata['node_feats'] = self.linear(
        torch.cat((g.ndata['node_feats'], g.ndata['reduced']), dim=-1))
    if self.activation is not None:
        g.ndata['node_feats'] = self.activation(g.ndata['node_feats'])
    return g
def forward(self, feat):
    g = self.graph.local_var()
    g.ndata['h'] = feat.mm(getattr(self, 'W'))
    g.update_all(fn.src_mul_edge(src='h', edge='w', out='m'),
                 fn.sum(msg='m', out='h'))
    rst = g.ndata['h']
    # rst = self.linear(rst)
    rst = self.activation(rst)
    return rst
def __init__(self, in_feats, out_feats, last=False):
    super(GCNLayer, self).__init__()
    self.linear = nn.Linear(in_feats, out_feats)
    self.last = last
    # multiply src with edge data or not
    # self.msg_func = fn.copy_src(src='h', out='m')
    self.msg_func = fn.src_mul_edge(src='h', edge='w', out='m')
    self.reduce_func = fn.sum(msg='m', out='h')
def propagate_attention(self, g):
    # Compute attention score
    g.apply_edges(src_dot_dst('k', 'q', 'score'))
    g.apply_edges(scaled_exp('score', math.sqrt(self.d_k)))
    # Update node state
    g.update_all(fn.src_mul_edge('v', 'score', 'v'),
                 fn.sum('v', 'wv'))
    g.update_all(fn.copy_edge('score', 'score'),
                 fn.sum('score', 'z'),
                 div_by_z('wv', 'z', 'o'))
    out_x = g.nodes['schema'].data['o']
    return out_x
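# NOTE: `div_by_z` above is an apply-node function that finishes the softmax
# by dividing the weighted values by the normalizer; a minimal sketch under
# the same assumptions as the helpers earlier:
def div_by_z(in_field, z_field, out_field):
    # normalize the aggregated values: o = wv / z
    def func(nodes):
        return {out_field: nodes.data[in_field] / nodes.data[z_field]}
    return func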
def propagate_attention(self, g):
    # Compute attention score
    g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'))  # , edges)
    g.apply_edges(scaled_exp('score', np.sqrt(self.out_dim)))
    # Send weighted values to target nodes
    eids = g.edges()
    g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score', 'V_h'),
                    fn.sum('V_h', 'wV'))
    g.send_and_recv(eids, fn.copy_edge('score', 'score'),
                    fn.sum('score', 'z'))
def forward(self, x):
    x = torch.matmul(x, self.weight)
    x = x.reshape((x.size(0), self.heads, -1))                # N x H x D'
    head_x = x.transpose(0, 1)                                # H x N x D'
    a1 = torch.bmm(head_x, self.att_l).transpose(0, 1)        # N x H x 1
    a2 = torch.bmm(head_x, self.att_r).transpose(0, 1)        # N x H x 1
    self.g.ndata.update({'x': x, 'a1': a1, 'a2': a2})
    self.g.apply_edges(self.edge_attention)
    self.edge_softmax()
    self.g.update_all(fn.src_mul_edge('x', 'a', 'x'),
                      fn.sum('x', 'x'))
    x = self.g.ndata['x'] / self.g.ndata['z']                 # N x H x D'
    return x.view(-1, self.heads * self.out_channels)
def forward(self, x):
    ft = self.fc(x).reshape((x.shape[0], self.heads, -1))     # N x H x D'
    head_ft = ft.transpose(0, 1)                              # H x N x D'
    a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)      # N x H x 1
    a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)      # N x H x 1
    self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
    self.g.apply_edges(self.edge_attention)
    self.edge_softmax()
    self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                      fn.sum('ft', 'ft'))
    ret = self.g.ndata['ft'] / self.g.ndata['z']              # N x H x D'
    return ret.view(-1, self.heads * self.out_channels)
def _test(fld):
    def message_func(edges):
        return {'m': edges.src[fld]}

    def message_func_edge(edges):
        if len(edges.src[fld].shape) == 1:
            return {'m': edges.src[fld] * edges.data['e1']}
        else:
            return {'m': edges.src[fld] * edges.data['e2']}

    def reduce_func(nodes):
        return {fld: F.sum(nodes.mailbox['m'], 1)}

    def apply_func(nodes):
        return {fld: 2 * nodes.data[fld]}

    g = generate_graph()
    # send and recv
    v1 = g.ndata[fld]
    g.send_and_recv((u, v), fn.copy_src(src=fld, out='m'),
                    fn.sum(msg='m', out=fld), apply_func)
    v2 = g.ndata[fld]
    g.set_n_repr({fld: v1})
    g.send_and_recv((u, v), message_func, reduce_func, apply_func)
    v3 = g.ndata[fld]
    assert F.allclose(v2, v3)
    # send and recv with edge weights
    v1 = g.ndata[fld]
    g.send_and_recv((u, v), fn.src_mul_edge(src=fld, edge='e1', out='m'),
                    fn.sum(msg='m', out=fld), apply_func)
    v2 = g.ndata[fld]
    g.set_n_repr({fld: v1})
    g.send_and_recv((u, v), fn.src_mul_edge(src=fld, edge='e2', out='m'),
                    fn.sum(msg='m', out=fld), apply_func)
    v3 = g.ndata[fld]
    g.set_n_repr({fld: v1})
    g.send_and_recv((u, v), message_func_edge, reduce_func, apply_func)
    v4 = g.ndata[fld]
    assert F.allclose(v2, v3)
    assert F.allclose(v3, v4)
def _test(fld):
    def message_func(edges):
        return {'m': edges.src[fld]}

    def message_func_edge(edges):
        if len(edges.src[fld].shape) == 1:
            return {'m': edges.src[fld] * edges.data['e1']}
        else:
            return {'m': edges.src[fld] * edges.data['e2']}

    def reduce_func(nodes):
        return {fld: mx.nd.max(nodes.mailbox['m'], axis=1)}

    def apply_func(nodes):
        return {fld: 2 * nodes.data[fld]}

    g = simple_graph()
    # update all
    v1 = g.ndata[fld]
    g.update_all(fn.copy_src(src=fld, out='m'),
                 fn.max(msg='m', out=fld), apply_func)
    v2 = g.ndata[fld]
    g.set_n_repr({fld: v1})
    g.update_all(message_func, reduce_func, apply_func)
    v3 = g.ndata[fld]
    assert np.allclose(v2.asnumpy(), v3.asnumpy(), rtol=1e-05, atol=1e-05)
    # update all with edge weights
    v1 = g.ndata[fld]
    g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                 fn.max(msg='m', out=fld), apply_func)
    v2 = g.ndata[fld]
    g.set_n_repr({fld: v1})
    g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
                 fn.max(msg='m', out=fld), apply_func)
    v3 = g.ndata[fld].squeeze()
    g.set_n_repr({fld: v1})
    g.update_all(message_func_edge, reduce_func, apply_func)
    v4 = g.ndata[fld]
    assert np.allclose(v2.asnumpy(), v3.asnumpy(), rtol=1e-05, atol=1e-05)
    assert np.allclose(v3.asnumpy(), v4.asnumpy(), rtol=1e-05, atol=1e-05)
def forward(self, g):
    alpha_prime = self.leaky_relu(self.attn(g.edata[self.attn_key]))
    # the key step: multiply the attention weights with the edge embedding
    g.edata['a'] = dglnn.edge_softmax(g, alpha_prime) * \
        g.edata['emb'].view(g.edata['emb'].shape[0], self.n_heads, -1)
    attn_emb = g.ndata[self.msg_key]
    if attn_emb.ndimension() == 2:
        g.ndata[self.msg_key] = attn_emb.view(g.number_of_nodes(),
                                              self.n_heads, -1)
    g.update_all(fn.src_mul_edge(self.msg_key, 'a', 'm'),
                 fn.sum('m', 'emb'))
    return GraphLambda(lambda x: x.view(x.shape[0], -1))(g)
def propagate_attention(self, g, eids):
    # Compute attention score
    g.apply_edges(src_dot_dst("k", "q", "score"), eids)
    g.apply_edges(scaled_exp("score", np.sqrt(self.d_k)), eids)
    # Send weighted values to target nodes
    g.send_and_recv(
        eids,
        [fn.src_mul_edge("v", "score", "v"),
         fn.copy_edge("score", "score")],
        [fn.sum("v", "wv"), fn.sum("score", "z")],
    )
def propagate_attention(self, g, eids, per_head=False):
    # Compute attention score
    if per_head:
        # per_head holds one edge-id set per attention head
        for i in range(len(per_head)):
            score_key = 'score{}'.format(i)
            g.apply_edges(src_dot_dst('k', 'q', score_key, i), per_head[i])
            g.apply_edges(scaled_exp(score_key, np.sqrt(self.d_k)),
                          per_head[i])
            # Send weighted values to target nodes
            g.send_and_recv(per_head[i],
                            [fn.src_mul_edge('v', score_key, 'v'),
                             fn.copy_edge(score_key, score_key)],
                            [fn.sum('v', 'wv'), fn.sum(score_key, 'z')])
    else:
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'),
                         fn.copy_edge('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])
def test_src_mul_edge():
    # src_mul_edge with all fields
    g = generate_graph()
    g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
    g.register_reduce_func(reducer_both)
    # test with update_all
    g.update_all()
    assert F.allclose(g.ndata.pop('out'),
                      F.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
    # test with send and then recv
    g.send()
    g.recv()
    assert F.allclose(g.ndata.pop('out'),
                      F.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
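# NOTE: every snippet here uses `fn.src_mul_edge`, which newer DGL releases
# deprecate in favor of the equivalent builtin `fn.u_mul_e`. A minimal,
# self-contained sketch of the same weighted aggregation in current DGL
# (toy graph and feature sizes chosen arbitrarily for illustration):
import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 0, 1]), torch.tensor([1, 2, 2])))
g.ndata['h'] = torch.randn(3, 4)  # node features
g.edata['w'] = torch.rand(3, 1)   # scalar edge weights, broadcast over features
# multiply source features by edge data, then sum the messages per destination
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
print(g.ndata['h_new'].shape)     # torch.Size([3, 4])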