def test_pull_0deg(): g = DGLGraph() g.add_nodes(2) g.add_edge(0, 1) def _message(edges): return {'m': edges.src['h']} def _reduce(nodes): return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)} def _apply(nodes): return {'h': nodes.data['h'] * 2} def _init2(shape, dtype, ctx, ids): return 2 + th.zeros(shape, dtype=dtype, device=ctx) g.register_message_func(_message) g.register_reduce_func(_reduce) g.register_apply_node_func(_apply) g.set_n_initializer(_init2, 'h') # test#1: pull both 0deg and non-0deg nodes old = th.randn((2, 5)) g.ndata['h'] = old g.pull([0, 1]) new = g.ndata.pop('h') # 0deg check: initialized with the func and got applied assert U.allclose(new[0], th.full((5, ), 4)) # non-0deg check assert U.allclose(new[1], th.sum(old, 0) * 2) # test#2: pull only 0deg node old = th.randn((2, 5)) g.ndata['h'] = old g.pull(0) new = g.ndata.pop('h') # 0deg check: fallback to apply assert U.allclose(new[0], 2 * old[0]) # non-0deg check: not touched assert U.allclose(new[1], old[1])
def check_pull_0deg(readonly): if readonly: row_idx = [] col_idx = [] row_idx.append(0) col_idx.append(1) ones = np.ones(shape=(len(row_idx))) csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(2, 2)) g = DGLGraph(csr, readonly=True) else: g = DGLGraph() g.add_nodes(2) g.add_edge(0, 1) def _message(edges): return {'m': edges.src['h']} def _reduce(nodes): return {'h': nodes.mailbox['m'].sum(1)} def _apply(nodes): return {'h': nodes.data['h'] * 2} def _init2(shape, dtype, ctx, ids): return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx) g.set_n_initializer(_init2, 'h') old_repr = mx.nd.random.normal(shape=(2, 5)) # test#1: pull only 0-deg node g.ndata['h'] = old_repr g.pull(0, _message, _reduce, _apply) new_repr = g.ndata['h'] # 0deg check: equal to apply_nodes assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy() * 2) # non-0deg check: untouched assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy()) # test#2: pull only non-deg node g.ndata['h'] = old_repr g.pull(1, _message, _reduce, _apply) new_repr = g.ndata['h'] # 0deg check: untouched assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy()) # non-0deg check: recved node0 and got applied assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2) # test#3: pull only both nodes g.ndata['h'] = old_repr g.pull([0, 1], _message, _reduce, _apply) new_repr = g.ndata['h'] # 0deg check: init and applied t = mx.nd.zeros(shape=(2, 5)) + 4 assert np.allclose(new_repr[0].asnumpy(), t.asnumpy()) # non-0deg check: recv node0 and applied assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2)