예제 #1
0
def track_time(batch_size, feat_size, readout_op, type):
    """Benchmark the average wall time of one batched readout call.

    A slice of ``batch_size`` samples is taken from ``QM7bDataset``, its
    first element is batched into a single graph on the benchmark device,
    a random feature ``'h'`` of width ``feat_size`` is attached to either
    the nodes or the edges, and the matching DGL readout is timed.

    Parameters
    ----------
    batch_size : int
        Upper bound of the dataset slice that is batched together.
    feat_size : int
        Second dimension of the random feature tensor.
    readout_op : str
        Reducer forwarded to the readout function as ``op=``.
    type : str
        ``'node'`` to time ``dgl.readout_nodes``, ``'edge'`` to time
        ``dgl.readout_edges``.

    Returns
    -------
    float
        ``elapsed_secs`` reported by ``utils.Timer`` divided by the number
        of timed iterations.

    Raises
    ------
    Exception
        If ``type`` is neither ``'node'`` nor ``'edge'``.
    """
    warmup_rounds = 10
    timed_rounds = 50

    device = utils.get_bench_device()
    dataset = dgl.data.QM7bDataset()
    # The first element of the sliced dataset is what gets batched.
    batched = dgl.batch(dataset[0:batch_size][0]).to(device)

    # Pick the readout function, the feature frame it reads from, and the
    # number of rows the feature tensor needs.
    if type == 'node':
        readout = dgl.readout_nodes
        frame = batched.ndata
        num_items = batched.num_nodes()
    elif type == 'edge':
        readout = dgl.readout_edges
        frame = batched.edata
        num_items = batched.num_edges()
    else:
        raise Exception("Unknown type")

    frame['h'] = torch.randn((num_items, feat_size), device=device)

    # Untimed warm-up calls before the measured loop.
    for _ in range(warmup_rounds):
        result = readout(batched, 'h', op=readout_op)

    with utils.Timer() as timer:
        for _ in range(timed_rounds):
            result = readout(batched, 'h', op=readout_op)

    return timer.elapsed_secs / timed_rounds
예제 #2
0
def test_weighted_reduce_readout(g, idtype, reducer):
    """Check weighted readouts on a batched graph against its components.

    Random features ``'h'`` and weights ``'w'`` are attached to the nodes
    and edges of ``g``. For each readout entry point (the generic
    ``dgl.readout_nodes`` / ``dgl.readout_edges`` with ``op=reducer`` and
    the ``dgl.<reducer>_nodes`` / ``dgl.<reducer>_edges`` shortcuts), the
    result on the whole graph must match the concatenation of the results
    on every graph returned by ``dgl.unbatch``.

    Parameters
    ----------
    g : graph
        Input graph; cast to ``idtype`` and moved to ``F.ctx()``.
    idtype
        Index dtype passed to ``g.astype``.
    reducer : str
        Reducer name, used both as ``op=`` and to build the shortcut
        function names.
    """
    g = g.astype(idtype).to(F.ctx())

    # Feature widths: 3 for nodes, 2 for edges; weights are single-column.
    g.ndata['h'] = F.randn((g.number_of_nodes(), 3))
    g.ndata['w'] = F.randn((g.number_of_nodes(), 1))
    g.edata['h'] = F.randn((g.number_of_edges(), 2))
    g.edata['w'] = F.randn((g.number_of_edges(), 1))

    def check_against_components(readout):
        # Readout on the batched graph first, then on each component.
        whole = readout(g)
        pieces = [readout(part) for part in dgl.unbatch(g)]
        assert F.allclose(whole, F.cat(pieces, dim=0))

    # Test.1: node readout
    check_against_components(
        lambda graph: dgl.readout_nodes(graph, 'h', 'w', op=reducer))
    node_shortcut = getattr(dgl, '{}_nodes'.format(reducer))
    check_against_components(lambda graph: node_shortcut(graph, 'h', 'w'))

    # Test.2: edge readout
    check_against_components(
        lambda graph: dgl.readout_edges(graph, 'h', 'w', op=reducer))
    edge_shortcut = getattr(dgl, '{}_edges'.format(reducer))
    check_against_components(lambda graph: edge_shortcut(graph, 'h', 'w'))
예제 #3
0
def track_time(batch_size, feat_size, readout_op, type):
    """Benchmark the average wall time of one batched readout call.

    A slice of ``batch_size`` samples is taken from ``QM7bDataset``, its
    first element is batched into a single graph on the benchmark device,
    a random feature ``'h'`` of width ``feat_size`` is attached to either
    the nodes or the edges, and the matching DGL readout is timed.

    Parameters
    ----------
    batch_size : int
        Upper bound of the dataset slice that is batched together.
    feat_size : int
        Second dimension of the random feature tensor.
    readout_op : str
        Reducer forwarded to the readout function as ``op=``.
    type : str
        ``'node'`` to time ``dgl.readout_nodes``, ``'edge'`` to time
        ``dgl.readout_edges``.

    Returns
    -------
    float
        Average seconds per readout call over the timed iterations.

    Raises
    ------
    Exception
        If ``type`` is neither ``'node'`` nor ``'edge'``.
    """
    warmup_rounds = 3
    timed_rounds = 10

    device = utils.get_bench_device()
    ds = dgl.data.QM7bDataset()
    # prepare graph
    graphs = ds[0:batch_size][0]
    g = dgl.batch(graphs).to(device)

    if type == 'node':
        readout = dgl.readout_nodes
        g.ndata['h'] = torch.randn((g.num_nodes(), feat_size), device=device)
    elif type == 'edge':
        readout = dgl.readout_edges
        g.edata['h'] = torch.randn((g.num_edges(), feat_size), device=device)
    else:
        raise Exception("Unknown type")

    torch_device = torch.device(device)

    def _sync():
        # CUDA work is queued asynchronously; wait for it so the clock
        # covers the kernels themselves and not only their launch.
        if torch_device.type == 'cuda':
            torch.cuda.synchronize(torch_device)

    # Untimed warm-up so one-time setup cost is not measured.
    # The reducer must be passed as ``op=``: the third positional argument
    # of the readout functions is the weight feature name, not the reducer.
    for _ in range(warmup_rounds):
        out = readout(g, 'h', op=readout_op)

    _sync()
    # perf_counter is monotonic, unlike time.time().
    t0 = time.perf_counter()
    for _ in range(timed_rounds):
        out = readout(g, 'h', op=readout_op)
    _sync()
    t1 = time.perf_counter()

    return (t1 - t0) / timed_rounds