def test_segment_reduce(reducer):
    """Check ``segment_reduce`` against an equivalent ``gspmm`` reduction.

    A bipartite graph is built in which destination node ``j`` receives
    ``lengths[j]`` consecutive rows of a random 10x5 feature tensor. With that
    graph, ``gspmm(graph, 'copy_lhs', reducer, ...)`` reduces the same row
    segments that ``segment_reduce`` does. The test asserts that both the
    forward outputs and the gradients w.r.t. the features agree. The segment
    lengths include zero-length segments.

    Args:
        reducer: reduction name forwarded to both ``gspmm`` and
            ``segment_reduce``.
    """
    device = F.ctx()
    features = F.tensor(np.random.rand(10, 5))
    # Two independent copies so each path accumulates its own gradient.
    feat_graph = F.attach_grad(F.clone(features))
    feat_segment = F.attach_grad(F.clone(features))
    lengths = F.tensor([2, 3, 0, 4, 1, 0, 0])

    num_rows = F.shape(features)[0]
    src = F.copy_to(F.arange(0, num_rows, F.int32), device)
    num_segments = len(lengths)
    segment_ids = F.copy_to(F.arange(0, num_segments, F.int32), device)
    # Segment id j repeated lengths[j] times: edge k links row k to the
    # segment that contains it.
    dst = F.repeat(segment_ids, lengths, dim=0)
    graph = dgl.convert.heterograph(
        {('_U', '_E', '_V'): (src, dst)},
        num_nodes_dict={'_U': len(src), '_V': num_segments},
    )

    with F.record_grad():
        out_graph = gspmm(graph, 'copy_lhs', reducer, feat_graph, None)
        if reducer in ('max', 'min'):
            # Presumably destinations with no in-edges come back as +/-inf for
            # max/min; zero them so they are comparable.
            out_graph = F.replace_inf_with_zero(out_graph)
        F.backward(F.reduce_sum(out_graph))
        grad_graph = F.grad(feat_graph)

    with F.record_grad():
        out_segment = segment_reduce(lengths, feat_segment, reducer=reducer)
        F.backward(F.reduce_sum(out_segment))
        assert F.allclose(out_graph, out_segment)
        print('forward passed')
        grad_segment = F.grad(feat_segment)
        assert F.allclose(grad_graph, grad_segment)
        print('backward passed')
def _assert_spmm_grad_close(actual, expected, reducer):
    """Assert that two gradient tensors agree.

    For ``'min'``/``'max'`` reducers the check is relative and tolerant: the
    summed absolute difference divided by the summed absolute magnitude of
    ``expected`` must be below 1e-2 (there might be some numerical errors
    between the two implementations). For every other reducer the tensors
    must satisfy ``F.allclose``.
    """
    if reducer in ['min', 'max']:
        rate = F.reduce_sum(F.abs(actual - expected)) / \
            F.reduce_sum(F.abs(expected))
        assert F.as_scalar(rate) < 1e-2, rate
    else:
        assert F.allclose(actual, expected)


def test_spmm(idtype, g, shp, msg, reducer):
    """Check the fused ``gspmm`` kernel against the UDF ``update_all`` path.

    Random source-node and edge features (shifted by +1) are generated, then:

    1. ``gspmm(g, msg, reducer, u, e)`` is evaluated and backpropagated.
    2. ``g.update_all(udf_msg[msg], udf_reduce[reducer])`` is evaluated. The
       reducer is expected to write its result to ``g.dstdata['v']``, which is
       compared with step 1 and then backpropagated.

    Gradients w.r.t. the source features are compared unless
    ``msg == 'copy_rhs'``. Gradients w.r.t. the edge features are compared
    unless ``msg == 'copy_lhs'``. When ``g`` has no edges, the comparisons
    are skipped.

    Args:
        idtype: index dtype ``g`` is converted to.
        g: graph under test; it is moved to ``F.ctx()``.
        shp: pair ``(src_feat_shape, edge_feat_shape)`` of shape tuples.
        msg: message op name, used both for ``gspmm`` and as a key into
            ``udf_msg``.
        reducer: reduce op name, used both for ``gspmm`` and as a key into
            ``udf_reduce``.
    """
    g = g.astype(idtype).to(F.ctx())
    print(g)
    print(g.idtype)

    hu = F.tensor(np.random.rand(*((g.number_of_src_nodes(), ) + shp[0])) + 1)
    he = F.tensor(np.random.rand(*((g.number_of_edges(), ) + shp[1])) + 1)
    print('u shape: {}, e shape: {}'.format(F.shape(hu), F.shape(he)))

    has_edges = g.number_of_edges() > 0
    try:
        g.srcdata['x'] = F.attach_grad(F.clone(hu))
        g.edata['w'] = F.attach_grad(F.clone(he))
        print('SpMM(message func: {}, reduce func: {})'.format(msg, reducer))

        # Fused kernel path.
        u = F.attach_grad(F.clone(hu))
        e = F.attach_grad(F.clone(he))
        with F.record_grad():
            v = gspmm(g, msg, reducer, u, e)
            if reducer in ['max', 'min']:
                v = F.replace_inf_with_zero(v)
            if has_edges:
                F.backward(F.reduce_sum(v))
                if msg != 'copy_rhs':
                    grad_u = F.grad(u)
                if msg != 'copy_lhs':
                    grad_e = F.grad(e)

        # Reference UDF path on the features stored in the graph.
        with F.record_grad():
            g.update_all(udf_msg[msg], udf_reduce[reducer])
            if has_edges:
                v1 = g.dstdata['v']
                assert F.allclose(v, v1)
                print('forward passed')

                F.backward(F.reduce_sum(v1))
                if msg != 'copy_rhs':
                    _assert_spmm_grad_close(
                        F.grad(g.srcdata['x']), grad_u, reducer)
                if msg != 'copy_lhs':
                    _assert_spmm_grad_close(
                        F.grad(g.edata['w']), grad_e, reducer)
                print('backward passed')
    finally:
        # astype()/to() may hand back the caller's own graph object, so
        # always strip the features this test attached, even when an
        # assertion above failed.
        for view, key in ((g.srcdata, 'x'), (g.edata, 'w'), (g.dstdata, 'v')):
            if key in view:
                view.pop(key)