Esempio n. 1
0
def test_prop_edges_dfs():
    """Check ``dgl.prop_edges_dfs`` on a 5-node path graph (legacy MXNet API).

    The message and reduce functions are registered on the graph once. Each
    scenario then resets every node feature to ones, runs a DFS propagation
    from node 0 with a different set of edge-type flags, and compares the
    resulting node features against a hand-computed table.
    """
    graph = dgl.DGLGraph(nx.path_graph(5))
    graph.register_message_func(mfunc)
    graph.register_reduce_func(rfunc)

    scenarios = (
        # Propagating along DFS tree edges only results in a cumsum.
        ({}, [[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]]),
        # With reverse edges the result is cumsum[i] + cumsum[i-1].
        ({'has_reverse_edge': True},
         [[1., 1.], [3., 3.], [5., 5.], [7., 7.], [9., 9.]]),
        # With non-tree edges the result is cumsum[i] + cumsum[i+1].
        ({'has_nontree_edge': True},
         [[3., 3.], [5., 5.], [7., 7.], [9., 9.], [5., 5.]]),
    )

    for flags, expected_rows in scenarios:
        graph.ndata['x'] = mx.nd.ones(shape=(5, 2))
        dgl.prop_edges_dfs(graph, 0, **flags)
        actual = graph.ndata['x'].asnumpy()
        assert np.allclose(actual, np.array(expected_rows))
    def forward(
            self, graph: dgl.DGLGraph
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Encode a (batched) graph by DFS message propagation from its roots.

        Steps, as implemented below:

        1. Apply ``self.dropout`` to ``graph.ndata['x']`` and store the result
           back into the graph.
        2. Create per-node hidden/cell states ``h`` and ``c`` as zeros with
           ``graph.number_of_nodes()`` rows and ``self.h_enc`` columns. Fill
           the root rows with the output of ``self.lstm`` on the root
           features.
        3. Propagate with ``dgl.prop_edges_dfs`` from the roots, using
           ``self.message_func`` / ``self.reduce_func`` with ``reverse=True``.
        4. Pop ``h`` and ``c`` from ``graph.ndata`` and return them.

        Args:
            graph: Input graph. Presumably a batch of trees (root positions
                are derived from ``graph.batch_num_nodes``). ``ndata['x']`` is
                assumed to be (num_nodes, in_dim); TODO confirm.

        Returns:
            ``(h, c)``: the per-node hidden and cell states after propagation.

        Note:
            NOTE(review): ``graph.ndata['x']`` is overwritten with its
            dropped-out version, so the caller's graph is mutated. This is
            presumably kept so that ``self.message_func`` sees the dropped-out
            features; confirm before changing.
        """
        graph.ndata['x'] = self.dropout(graph.ndata['x'])
        features = graph.ndata['x']

        # ``get_root_indexes`` is defined elsewhere. Presumably it maps
        # per-graph node counts to the global index of each graph's root.
        root_indexes = features.new_tensor(
            get_root_indexes(graph.batch_num_nodes),
            dtype=torch.long,
            requires_grad=False,
        )

        # Build the initial states locally and hand them to the graph
        # afterwards, rather than writing in place through
        # ``graph.ndata['h'][...]``. The latter only works if ``ndata[...]``
        # returns the backing storage rather than a copy.
        num_nodes = graph.number_of_nodes()
        hidden = features.new_zeros((num_nodes, self.h_enc))
        cell = features.new_zeros((num_nodes, self.h_enc))
        # ``self.lstm`` is expected to return a 2-tuple, presumably
        # (h, c) as from an LSTM cell; confirm.
        hidden[root_indexes], cell[root_indexes] = self.lstm(
            features[root_indexes])
        graph.ndata['h'] = hidden
        graph.ndata['c'] = cell

        # ``reverse`` is passed by keyword. Positionally, the third parameter
        # differs between DGL versions (``reverse`` vs ``message_func``).
        dgl.prop_edges_dfs(graph,
                           root_indexes,
                           reverse=True,
                           message_func=self.message_func,
                           reduce_func=self.reduce_func)

        return graph.ndata.pop('h'), graph.ndata.pop('c')
Esempio n. 3
0
def test_prop_edges_dfs(idtype):
    """Check ``dgl.prop_edges_dfs`` on a 5-node path graph.

    Uses the backend-agnostic ``F`` helpers and passes the message/reduce
    functions explicitly on every call. Each case resets the node features
    to ones, propagates from node 0 with different edge-type flags, and
    compares the features with a precomputed table.

    Args:
        idtype: ID dtype used to build the graph (supplied by the test
            harness; presumably a parametrized fixture).
    """
    graph = dgl.graph(nx.path_graph(5), idtype=idtype, device=F.ctx())

    cases = [
        # Propagating along DFS tree edges only results in a cumsum.
        ({}, [[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]]),
        # With reverse edges the result is cumsum[i] + cumsum[i-1].
        (dict(has_reverse_edge=True),
         [[1., 1.], [3., 3.], [5., 5.], [7., 7.], [9., 9.]]),
        # With non-tree edges the result is cumsum[i] + cumsum[i+1].
        (dict(has_nontree_edge=True),
         [[3., 3.], [5., 5.], [7., 7.], [9., 9.], [5., 5.]]),
    ]

    for extra_flags, expected in cases:
        graph.ndata['x'] = F.ones((5, 2))
        dgl.prop_edges_dfs(graph, 0,
                           message_func=mfunc,
                           reduce_func=rfunc,
                           apply_node_func=None,
                           **extra_flags)
        assert F.allclose(graph.ndata['x'], F.tensor(expected))