Ejemplo n.º 1
0
def test_pull_0deg():
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)

    def _message(edges):
        return {'m': edges.src['h']}

    def _reduce(nodes):
        return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}

    def _apply(nodes):
        return {'h': nodes.data['h'] * 2}

    def _init2(shape, dtype, ctx, ids):
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)

    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    g.set_n_initializer(_init2, 'h')
    # test#1: pull both 0deg and non-0deg nodes
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.pull([0, 1])
    new = g.ndata.pop('h')
    # 0deg check: initialized with the func and got applied
    assert U.allclose(new[0], th.full((5, ), 4))
    # non-0deg check
    assert U.allclose(new[1], th.sum(old, 0) * 2)

    # test#2: pull only 0deg node
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.pull(0)
    new = g.ndata.pop('h')
    # 0deg check: fallback to apply
    assert U.allclose(new[0], 2 * old[0])
    # non-0deg check: not touched
    assert U.allclose(new[1], old[1])
Ejemplo n.º 2
0
def check_pull_0deg(readonly):
    if readonly:
        row_idx = []
        col_idx = []
        row_idx.append(0)
        col_idx.append(1)
        ones = np.ones(shape=(len(row_idx)))
        csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(2, 2))
        g = DGLGraph(csr, readonly=True)
    else:
        g = DGLGraph()
        g.add_nodes(2)
        g.add_edge(0, 1)

    def _message(edges):
        return {'m': edges.src['h']}

    def _reduce(nodes):
        return {'h': nodes.mailbox['m'].sum(1)}

    def _apply(nodes):
        return {'h': nodes.data['h'] * 2}

    def _init2(shape, dtype, ctx, ids):
        return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx)

    g.set_n_initializer(_init2, 'h')
    old_repr = mx.nd.random.normal(shape=(2, 5))

    # test#1: pull only 0-deg node
    g.ndata['h'] = old_repr
    g.pull(0, _message, _reduce, _apply)
    new_repr = g.ndata['h']
    # 0deg check: equal to apply_nodes
    assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy() * 2)
    # non-0deg check: untouched
    assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())

    # test#2: pull only non-deg node
    g.ndata['h'] = old_repr
    g.pull(1, _message, _reduce, _apply)
    new_repr = g.ndata['h']
    # 0deg check: untouched
    assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
    # non-0deg check: recved node0 and got applied
    assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2)

    # test#3: pull only both nodes
    g.ndata['h'] = old_repr
    g.pull([0, 1], _message, _reduce, _apply)
    new_repr = g.ndata['h']
    # 0deg check: init and applied
    t = mx.nd.zeros(shape=(2, 5)) + 4
    assert np.allclose(new_repr[0].asnumpy(), t.asnumpy())
    # non-0deg check: recv node0 and applied
    assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2)