def test_recv_0deg_newfld():
    """Recv on a mix of 0-degree and non-0-degree nodes, where the reducer
    writes a brand-new field ('h1') rather than updating an existing one."""
    g = dgl.graph([(0,1)])

    def msg_fn(edges):
        return {'m' : edges.src['h']}

    def red_fn(nodes):
        return {'h1' : nodes.data['h'] + F.sum(nodes.mailbox['m'], 1)}

    def apply_fn(nodes):
        return {'h1' : nodes.data['h1'] * 2}

    def init_fn(shape, dtype, ctx, ids):
        # Custom initializer: fresh fields start at 2 instead of 0.
        return 2 + F.zeros(shape, dtype=dtype, ctx=ctx)

    # Case 1: recv on both the 0-degree node (0) and the message target (1).
    old = F.randn((2, 5))
    g.set_n_initializer(init_fn, 'h1')
    g.ndata['h'] = old
    g.send((0, 1), msg_fn)
    g.recv([0, 1], red_fn, apply_fn)
    new = g.ndata.pop('h1')
    # Node 0 got no messages: 'h1' comes from the initializer (2), then apply doubles it.
    assert F.allclose(new[0], F.full_1d(5, 4, dtype=F.float32))
    # Node 1 reduced its one incoming message, then apply doubled the result.
    assert F.allclose(new[1], F.sum(old, 0) * 2)

    # Case 2: recv on the 0-degree node only.
    old = F.randn((2, 5))
    g.ndata['h'] = old
    g.ndata['h1'] = F.full((2, 5), -1, F.int64)  # pre-populate 'h1'; this is necessary
    g.send((0, 1), msg_fn)
    g.recv(0, red_fn, apply_fn)
    new = g.ndata.pop('h1')
    # Node 0: no message, so it falls back to apply only (-1 * 2).
    assert F.allclose(new[0], F.full_1d(5, -2, F.int64))
    # Node 1 was not part of the recv, so its value is untouched.
    assert F.allclose(new[1], F.full_1d(5, -1, F.int64))
def test_multi_recv_0deg():
    """Repeated recv calls involving 0-degree nodes, using UDFs registered
    on the graph and a custom node-field initializer."""
    g = DGLGraph()

    def msg_fn(edges):
        return {'m': edges.src['h']}

    def red_fn(nodes):
        return {'h': nodes.data['h'] + F.sum(nodes.mailbox['m'], 1)}

    def apply_fn(nodes):
        return {'h': nodes.data['h'] * 2}

    def init_fn(shape, dtype, ctx, ids):
        # Custom initializer: fresh fields start at 2 instead of 0.
        return 2 + F.zeros(shape, dtype=dtype, ctx=ctx)

    g.register_message_func(msg_fn)
    g.register_reduce_func(red_fn)
    g.register_apply_node_func(apply_fn)
    g.set_n_initializer(init_fn)
    g.add_nodes(2)
    g.add_edge(0, 1)

    # recv over both the 0-degree node and the message target
    old = F.randn((2, 5))
    g.ndata['h'] = old
    g.send((0, 1))
    g.recv([0, 1])
    new = g.ndata['h']
    # Node 0 (0-deg): initialized via init_fn (2), then doubled by apply.
    assert F.allclose(new[0], F.full((5, ), 4, F.float32))
    # Node 1 reduced its message, then apply doubled the result.
    assert F.allclose(new[1], F.sum(old, 0) * 2)

    # A second recv on the zero-degree node runs apply again: 4 -> 8.
    g.recv([0])
    assert F.allclose(g.nodes[0].data['h'], F.full((5, ), 8, F.float32))

    # recv on a node with no pending messages: apply only, value doubles again.
    g.recv([1])
    assert F.allclose(g.nodes[1].data['h'], F.sum(old, 0) * 4)
def _test(feat_scale):
    """Cross-check ``dgl.ops.segment_mm`` / ``dgl.ops.gather_mm`` against a
    per-relation "low-mem" matmul reference, for both forward and backward.

    Parameters
    ----------
    feat_scale : int
        Multiplier for the input (16x) and output (8x) feature sizes.

    NOTE(review): relies on ``n_edge_scale`` from the enclosing scope to
    scale per-relation edge counts — assumed defined elsewhere in this file.
    """
    in_feat = 16 * feat_scale
    out_feat = 8 * feat_scale
    print("in/out feat", in_feat, out_feat)

    # Number of edges for each relation type.
    E_per_rel = F.copy_to(
        F.tensor([
            50, 100, 20, 284, 89, 10, 82, 9200, 10, 20, 30, 100, 128, 20, 284,
            89, 10, 82, 92, 10, 20, 30, 100, 1280, 20, 284, 89, 1000, 82, 92,
            10, 2000, 30, 100, 128, 20, 284, 89, 10, 82, 92, 10, 20, 30
        ]), F.cpu())
    E_per_rel *= n_edge_scale
    num_rel = len(E_per_rel)
    print('num_rel', num_rel)
    # (Removed unused local W_per_len — it was computed but never read.)

    # Per-relation operands: inputs H, weights W, zero outputs, output grads.
    H_arr = []
    W_arr = []
    Out_arr = []
    Out_grad_arr = []
    for eid in range(num_rel):
        H_arr.append(F.randn((E_per_rel[eid], in_feat)))
        W_arr.append(F.randn((in_feat, out_feat)))
        Out_arr.append(F.zeros((E_per_rel[eid], out_feat)))
        Out_grad_arr.append(F.ones((E_per_rel[eid], out_feat)))
    # Concatenate the lists directly; copying them through an identity
    # comprehension first was redundant.
    H = F.cat(H_arr, 0)
    W = F.cat(W_arr, 0)
    W_3D = W.reshape(num_rel, in_feat, out_feat)
    Out = F.cat(Out_arr, 0)
    print('H.shape', H.shape)
    print('W.shape', W.shape)
    print('W_3D.shape', W_3D.shape)
    print('Out.shape', Out.shape)

    # Relation id for every edge, laid out to match the sorted H.
    etype_arr = []
    for eid in range(num_rel):
        etype_arr.append(
            F.full((E_per_rel[eid], ), eid, dtype=F.dtype(E_per_rel)))
    etypes = F.cat(etype_arr, 0)

    #################################################################
    # Reference: low-mem version, one matmul per relation
    #################################################################
    # forward pass
    out = []
    for i in range(num_rel):
        out.append(F.matmul(H_arr[i], W_arr[i]))
    out_low_mem = F.cat(out, 0)

    # backward pass: manual gradients w.r.t. H and W
    H_grad = []
    W_grad = []
    for i in range(num_rel):
        Hi = H_arr[i]
        Wi = W_arr[i]
        Out_gradi = Out_grad_arr[i]
        H_grad.append(F.matmul(Out_gradi, Wi.transpose(0, 1)))
        W_grad.append(F.matmul(Hi.transpose(0, 1), Out_gradi))
    Hgrad_low_mem = F.cat(H_grad, 0)
    Wgrad_low_mem = F.cat(W_grad, 0)
    Wgrad_low_mem = Wgrad_low_mem.reshape(num_rel, in_feat, out_feat)

    #################################################################
    # segment_mm: H sorted (grouped contiguously) by etype
    #################################################################
    seglen_A = E_per_rel
    F.attach_grad(H)
    F.attach_grad(W_3D)
    with F.record_grad():
        out_gmm_sorted = dgl.ops.segment_mm(H, W_3D, seglen_A)
        F.backward(F.reduce_sum(out_gmm_sorted))
    Hgrad_gmm_sorted = H.grad
    Wgrad_gmm_sorted = W_3D.grad

    #################################################################
    # gather_mm: H in arbitrary order, weight picked per row via etypes
    #################################################################
    F.attach_grad(H)
    F.attach_grad(W_3D)
    with F.record_grad():
        out_gmm_unsorted = dgl.ops.gather_mm(H, W_3D, idx_rhs=etypes)
        F.backward(F.reduce_sum(out_gmm_unsorted))
    Hgrad_gmm_unsorted = H.grad
    Wgrad_gmm_unsorted = W_3D.grad

    # correctness check against the low-mem reference
    assert F.allclose(out_low_mem, out_gmm_sorted, atol=1e-3, rtol=1e-3)
    assert F.allclose(Hgrad_low_mem, Hgrad_gmm_sorted, atol=1e-3, rtol=1e-3)
    assert F.allclose(Wgrad_low_mem, Wgrad_gmm_sorted, atol=1e-3, rtol=1e-3)
    assert F.allclose(out_low_mem, out_gmm_unsorted, atol=1e-3, rtol=1e-3)
    assert F.allclose(Hgrad_low_mem, Hgrad_gmm_unsorted, atol=1e-3, rtol=1e-3)
    assert F.allclose(Wgrad_low_mem, Wgrad_gmm_unsorted, atol=1e-3, rtol=1e-3)