def track_time(batch_size, feat_size, readout_op, type):
    """Benchmark the mean time of one ``dgl.readout_nodes``/``readout_edges`` call.

    Batches the first ``batch_size`` entries of ``dgl.data.QM7bDataset`` into a
    single graph on the benchmark device, attaches random ``'h'`` features of
    width ``feat_size`` to either its nodes or its edges, runs a few untimed
    warm-up readouts and then times a fixed number of readouts.

    Args:
        batch_size: number of dataset entries to batch together.
        feat_size: feature dimension of the random ``'h'`` features.
        readout_op: reduction name forwarded as ``op=`` to the readout function.
        type: ``'node'`` to benchmark ``dgl.readout_nodes`` or ``'edge'`` to
            benchmark ``dgl.readout_edges``. (Shadows the builtin, but the
            name is kept because callers may pass it by keyword.)

    Returns:
        Average seconds per readout call over the timed iterations.

    Raises:
        ValueError: if ``type`` is neither ``'node'`` nor ``'edge'``.
    """
    n_warmup = 10
    n_iters = 50

    # Validate before the dataset load / device transfer so a bad parameter
    # fails fast. ValueError subclasses Exception, so existing handlers still work.
    if type not in ('node', 'edge'):
        raise ValueError("Unknown type: {!r}".format(type))

    device = utils.get_bench_device()
    ds = dgl.data.QM7bDataset()
    # prepare graph
    graphs = ds[0:batch_size][0]
    g = dgl.batch(graphs).to(device)

    if type == 'node':
        g.ndata['h'] = torch.randn((g.num_nodes(), feat_size), device=device)
        readout = dgl.readout_nodes
    else:
        g.edata['h'] = torch.randn((g.num_edges(), feat_size), device=device)
        readout = dgl.readout_edges

    # Warm-up iterations are excluded from the measurement so one-off costs
    # (if any) on the first calls do not skew the average.
    for _ in range(n_warmup):
        readout(g, 'h', op=readout_op)
    with utils.Timer() as t:
        for _ in range(n_iters):
            readout(g, 'h', op=readout_op)
    return t.elapsed_secs / n_iters
def test_weighted_reduce_readout(g, idtype, reducer):
    """Weighted readouts on a batched graph must match per-component readouts.

    Node features ``'h'`` (width 3) and edge features ``'h'`` (width 2) are
    weighted by ``'w'``. For nodes and then edges, two entry points are
    checked: the generic ``dgl.readout_nodes``/``dgl.readout_edges`` with
    ``op=reducer``, and the reducer-named shortcut looked up on ``dgl``
    (``'{reducer}_nodes'`` / ``'{reducer}_edges'``). In each case the result on
    the batched graph must equal the concatenation of the results computed on
    every graph returned by ``dgl.unbatch``.
    """
    g = g.astype(idtype).to(F.ctx())
    node_count = g.number_of_nodes()
    edge_count = g.number_of_edges()
    g.ndata['h'] = F.randn((node_count, 3))
    g.ndata['w'] = F.randn((node_count, 1))
    g.edata['h'] = F.randn((edge_count, 2))
    g.edata['w'] = F.randn((edge_count, 1))

    def check_matches_unbatched(readout_fn, *args, **kwargs):
        # Readout of the whole batch vs. per-component readouts stacked on dim 0.
        whole = readout_fn(g, *args, **kwargs)
        parts = [readout_fn(component, *args, **kwargs) for component in dgl.unbatch(g)]
        assert F.allclose(whole, F.cat(parts, dim=0))

    checks = (
        (dgl.readout_nodes, '{}_nodes'),  # Test.1: node readout
        (dgl.readout_edges, '{}_edges'),  # Test.2: edge readout
    )
    for generic_readout, shortcut_template in checks:
        check_matches_unbatched(generic_readout, 'h', 'w', op=reducer)
        shortcut_readout = getattr(dgl, shortcut_template.format(reducer))
        check_matches_unbatched(shortcut_readout, 'h', 'w')
def track_time(batch_size, feat_size, readout_op, type):
    """Benchmark the mean wall time of one ``dgl.readout_nodes``/``readout_edges`` call.

    Batches the first ``batch_size`` entries of ``dgl.data.QM7bDataset`` into a
    single graph on the benchmark device, attaches random ``'h'`` features of
    width ``feat_size`` to either its nodes or its edges, runs untimed warm-up
    readouts and then times a fixed number of readouts with
    ``time.perf_counter``.

    Args:
        batch_size: number of dataset entries to batch together.
        feat_size: feature dimension of the random ``'h'`` features.
        readout_op: reduction name forwarded as ``op=`` to the readout function.
            (It must be passed by keyword: the third positional parameter of
            ``dgl.readout_nodes``/``readout_edges`` is the *weight* field name.)
        type: ``'node'`` to benchmark ``dgl.readout_nodes`` or ``'edge'`` to
            benchmark ``dgl.readout_edges``. (Shadows the builtin, but the
            name is kept because callers may pass it by keyword.)

    Returns:
        Average seconds per readout call over the timed iterations.

    Raises:
        ValueError: if ``type`` is neither ``'node'`` nor ``'edge'``.
    """
    n_warmup = 10
    n_iters = 10

    # Validate before the dataset load / device transfer so a bad parameter
    # fails fast. ValueError subclasses Exception, so existing handlers still work.
    if type not in ('node', 'edge'):
        raise ValueError("Unknown type: {!r}".format(type))

    device = utils.get_bench_device()
    ds = dgl.data.QM7bDataset()
    # prepare graph
    graphs = ds[0:batch_size][0]
    g = dgl.batch(graphs).to(device)

    if type == 'node':
        feat = torch.randn((g.num_nodes(), feat_size), device=device)
        g.ndata['h'] = feat
        readout = dgl.readout_nodes
    else:
        feat = torch.randn((g.num_edges(), feat_size), device=device)
        g.edata['h'] = feat
        readout = dgl.readout_edges

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the clock reading
        # covers the actual work rather than just the kernel launches.
        if feat.is_cuda:
            torch.cuda.synchronize()

    # Warm-up iterations are excluded from the measurement.
    for _ in range(n_warmup):
        readout(g, 'h', op=readout_op)
    _sync()
    t0 = time.perf_counter()
    for _ in range(n_iters):
        readout(g, 'h', op=readout_op)
    _sync()
    t1 = time.perf_counter()
    return (t1 - t0) / n_iters