def test_prop_edges_dfs():
    """Check DFS edge propagation on a 5-node path graph under three modes.

    The message/reduce functions are registered on the graph once; every case
    resets node feature 'x' to ones, propagates from node 0 and compares the
    result against the expected per-node values (both feature columns equal).
    """
    graph = dgl.DGLGraph(nx.path_graph(5))
    graph.register_message_func(mfunc)
    graph.register_reduce_func(rfunc)

    cases = [
        # snr using dfs results in a cumsum
        ({}, [1., 2., 3., 4., 5.]),
        # result is cumsum[i] + cumsum[i-1]
        ({'has_reverse_edge': True}, [1., 3., 5., 7., 9.]),
        # result is cumsum[i] + cumsum[i+1]
        ({'has_nontree_edge': True}, [3., 5., 7., 9., 5.]),
    ]
    for extra_kwargs, per_node in cases:
        graph.ndata['x'] = mx.nd.ones(shape=(5, 2))
        dgl.prop_edges_dfs(graph, 0, **extra_kwargs)
        expected = np.array([[value, value] for value in per_node])
        assert np.allclose(graph.ndata['x'].asnumpy(), expected)
def forward(
        self, graph: dgl.DGLGraph
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """Encode a (batched) graph by DFS edge propagation seeded at its roots.

    Dropout is applied to node features ``'x'`` for the duration of the call.
    The root nodes' ``'h'``/``'c'`` states are initialised from ``self.lstm``.
    All other nodes start at zero. ``dgl.prop_edges_dfs`` then runs with this
    module's ``message_func``/``reduce_func``.

    Args:
        graph: Graph whose ``ndata['x']`` holds node input features. The
            caller's ``'x'`` is restored before returning, so repeated calls
            do not compound dropout.

    Returns:
        Tuple ``(h, c)`` of per-node tensors, each of shape
        ``(num_nodes, self.h_enc)``. They are popped from ``graph.ndata``, so
        the graph does not keep them.
    """
    raw_x = graph.ndata['x']
    x = self.dropout(raw_x)
    # The message/reduce functions may read 'x' from the graph, so the
    # dropped-out features must be visible there during propagation.
    graph.ndata['x'] = x
    try:
        # presumably one root node index per graph in the batch -- verify
        # against get_root_indexes.
        root_indexes = x.new_tensor(
            get_root_indexes(graph.batch_num_nodes),
            dtype=torch.long, requires_grad=False)
        num_nodes = graph.number_of_nodes()
        h = x.new_zeros((num_nodes, self.h_enc))
        c = x.new_zeros((num_nodes, self.h_enc))
        # Fill root states on local tensors before storing them, so this does
        # not depend on graph.ndata returning a mutable alias of the stored
        # column.
        h[root_indexes], c[root_indexes] = self.lstm(x[root_indexes])
        graph.ndata['h'] = h
        graph.ndata['c'] = c
        dgl.prop_edges_dfs(graph, root_indexes, reverse=True,
                           message_func=self.message_func,
                           reduce_func=self.reduce_func)
        return graph.ndata.pop('h'), graph.ndata.pop('c')
    finally:
        # Put back the caller's original features (not the dropped-out ones).
        graph.ndata['x'] = raw_x
def test_prop_edges_dfs(idtype):
    """Check DFS edge propagation on a 5-node path graph under three modes.

    Each case resets node feature 'x' to ones, propagates from node 0 with
    explicit message/reduce functions and no apply function, then compares
    the result against the expected per-node values (both columns equal).
    """
    graph = dgl.graph(nx.path_graph(5), idtype=idtype, device=F.ctx())
    shared_funcs = {
        'message_func': mfunc,
        'reduce_func': rfunc,
        'apply_node_func': None,
    }

    cases = [
        # snr using dfs results in a cumsum
        ({}, [1., 2., 3., 4., 5.]),
        # result is cumsum[i] + cumsum[i-1]
        ({'has_reverse_edge': True}, [1., 3., 5., 7., 9.]),
        # result is cumsum[i] + cumsum[i+1]
        ({'has_nontree_edge': True}, [3., 5., 7., 9., 5.]),
    ]
    for mode_kwargs, per_node in cases:
        graph.ndata['x'] = F.ones((5, 2))
        dgl.prop_edges_dfs(graph, 0, **mode_kwargs, **shared_funcs)
        expected = F.tensor([[value, value] for value in per_node])
        assert F.allclose(graph.ndata['x'], expected)