示例#1
0
def check_reduce_0deg(readonly):
    """Check ``update_all`` on a star graph where only node 0 has in-edges.

    Builds a 5-node graph with edges ``1->0, 2->0, 3->0, 4->0`` and runs a
    copy-source message function with a sum reduce (no apply function).
    The assertions expect:

    * node 0 to end up with its own feature plus its four neighbours'
      features, i.e. the column sum of all five input rows;
    * the zero-in-degree nodes 1-4 to hold the value produced by the node
      initializer (a tensor filled with 2).

    Args:
        readonly: if true, build the graph from a SciPy CSR matrix as a
            read-only ``DGLGraph``; otherwise build a mutable ``DGLGraph``
            with ``add_nodes`` / ``add_edge``.
    """
    num_nodes = 5
    feat_dim = 5
    # Every node except 0 has exactly one edge pointing at node 0.
    src = list(range(1, num_nodes))
    dst = [0] * len(src)

    if readonly:
        ones = np.ones(shape=(len(src),))
        csr = spsp.csr_matrix((ones, (src, dst)),
                              shape=(num_nodes, num_nodes))
        g = DGLGraph(csr, readonly=True)
    else:
        g = DGLGraph()
        g.add_nodes(num_nodes)
        for u, v in zip(src, dst):
            g.add_edge(u, v)

    def _message(edges):
        # Send the source node's 'h' feature as the message.
        return {'m': edges.src['h']}

    def _reduce(nodes):
        # Add the mailbox summed over axis 1 (presumably the per-node
        # message axis -- confirm) to the node's current feature.
        return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}

    def _init2(shape, dtype, ctx, ids):
        # Node initializer for 'h': a tensor filled with 2 (ids is unused).
        return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx)

    g.set_n_initializer(_init2, 'h')
    old_repr = mx.nd.random.normal(shape=(num_nodes, feat_dim))
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce)
    new_repr = g.ndata['h']

    # Zero-in-degree nodes are expected to carry the initializer value.
    assert np.allclose(new_repr[1:].asnumpy(),
                       2 + np.zeros((len(src), feat_dim)))
    # Node 0: its own row plus rows 1..4 == sum over all rows.
    assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())
示例#2
0
def test_update_all_0deg():
    """Exercise ``update_all`` with message, reduce and apply UDFs on 0-deg nodes.

    Case 1 -- star graph with edges ``1->0 .. 4->0``: node 0 receives
    messages, so after apply it should equal twice the column sum of all
    input rows; nodes 1-4 receive nothing and are expected to hold the
    initializer value (2) doubled by the apply function.

    Case 2 -- 5 nodes and no edges: every node should simply be the apply
    function (doubling) applied to its input feature.
    """
    def _message(edges):
        # Message = the source node's feature.
        return {'m': edges.src['h']}

    def _reduce(nodes):
        # Own feature plus the mailbox summed over axis 1.
        return {'h': nodes.data['h'] + mx.nd.sum(nodes.mailbox['m'], 1)}

    def _apply(nodes):
        # Apply step: double the feature.
        return {'h': nodes.data['h'] * 2}

    def _init2(shape, dtype, ctx, ids):
        # Initializer for 'h': a tensor filled with 2 (ids is unused).
        return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx)

    # Case 1: star graph, only node 0 has incoming edges.
    star = DGLGraph()
    star.add_nodes(5)
    for u in (1, 2, 3, 4):
        star.add_edge(u, 0)
    star.set_n_initializer(_init2, 'h')
    feats = mx.nd.random.normal(shape=(5, 5))
    star.ndata['h'] = feats
    star.update_all(_message, _reduce, _apply)
    out = star.ndata['h']

    expected_leaves = np.full((4, 5), 4.0)  # apply(initializer) == 2 * 2
    expected_hub = 2 * mx.nd.sum(feats, 0).asnumpy()
    assert np.allclose(out[1:].asnumpy(), expected_leaves)
    assert np.allclose(out[0].asnumpy(), expected_hub)

    # Case 2: no edges at all -- update_all should reduce to just apply.
    empty = DGLGraph()
    empty.add_nodes(5)
    empty.set_n_initializer(_init2, 'h')
    empty.ndata['h'] = feats
    empty.update_all(_message, _reduce, _apply)
    assert np.allclose(empty.ndata['h'].asnumpy(), 2 * feats.asnumpy())