コード例 #1
0
ファイル: test_subgraph.py プロジェクト: yuk12/dgl
 def _check_typed_subgraph1(g, sg):
     """Assert that ``sg`` mirrors the typed graph ``g`` and tracks its features.

     Checks that ``sg`` has the same ID dtype and device as ``g``, exactly the
     node types {'user', 'game'} and edge types {'follows', 'plays',
     'wishes'}, the same node count per node type, and identical (src, dst)
     edge lists per edge type when ordered by edge ID.  It also checks that the
     'user' node feature 'h' and the 'follows' edge feature 'h' are equal on
     both graphs.

     NOTE: this helper mutates ``g``.  One row of each of those two features
     is overwritten with random values on ``g`` only, and ``sg`` is then
     expected to still agree with ``g``.  Presumably this is because ``sg``
     shares, rather than copies, those feature columns.  TODO confirm against
     the subgraph API under test.
     """
     def _features_in_sync():
         # Compare the two features this helper cares about on both graphs.
         assert F.array_equal(sg.nodes['user'].data['h'],
                              g.nodes['user'].data['h'])
         assert F.array_equal(sg.edges['follows'].data['h'],
                              g.edges['follows'].data['h'])

     # Same ID dtype and same device.
     assert g.idtype == sg.idtype
     assert g.device == sg.device

     # Exactly these node and edge types are expected on the subgraph.
     assert set(sg.ntypes) == {'user', 'game'}
     assert set(sg.etypes) == {'follows', 'plays', 'wishes'}

     # Every node type keeps all of its nodes.
     for nt in sg.ntypes:
         assert sg.number_of_nodes(nt) == g.number_of_nodes(nt)

     # Edge endpoints must match exactly when both sides are ordered by edge ID.
     for et in sg.etypes:
         sub_src, sub_dst = sg.all_edges(etype=et, order='eid')
         full_src, full_dst = g.all_edges(etype=et, order='eid')
         assert F.array_equal(sub_src, full_src)
         assert F.array_equal(sub_dst, full_dst)

     _features_in_sync()

     # Overwrite one row of each feature on the parent graph only.  The random
     # rows are (1, 5) and (1, 4), so these features are presumably 5- and
     # 4-dimensional respectively.
     g.nodes['user'].data['h'] = F.scatter_row(
         g.nodes['user'].data['h'], F.tensor([2]), F.randn((1, 5)))
     g.edges['follows'].data['h'] = F.scatter_row(
         g.edges['follows'].data['h'], F.tensor([1]), F.randn((1, 4)))

     # The subgraph is still expected to agree with the updated parent.
     _features_in_sync()
コード例 #2
0
def test_reorder_nodes():
    """Check that ``dgl.transform.reorder_nodes`` relabels nodes consistently.

    A random permutation ``new_nids`` is applied to a 1000-node read-only
    graph.  The test treats ``new_nids[i]`` as the new ID of old node ``i``.
    It expects ``new_g.ndata['orig_id']`` to map each new ID back to its old
    ID.  Under that relabeling it checks, in both directions, that in/out
    degrees and successor/predecessor sets are preserved.
    """
    def _assert_same_neighbors(expected, actual):
        # Compare neighbor lists as sorted multisets.  ``np.array_equal`` also
        # requires equal shapes.  A plain ``np.all(a == b)`` can pass on
        # mismatched lengths because of broadcasting: an empty array compared
        # with a 1-element array yields an empty result, and ``np.all([])`` is
        # True.
        assert np.array_equal(np.sort(expected), np.sort(actual))

    g = dgl.DGLGraph(create_large_graph_index(1000), readonly=True)
    new_nids = np.random.permutation(g.number_of_nodes())
    # TODO(zhengda) we need to test both CSR and COO.
    new_g = dgl.transform.reorder_nodes(g, new_nids)

    # Degrees: old node i becomes new node new_nids[i].  Scattering the old
    # degrees into positions new_nids must reproduce the new graph's degrees.
    # (scatter_row overwrites every row, so the base tensor only supplies the
    # shape and dtype.)
    new_in_deg = new_g.in_degrees()
    new_out_deg = new_g.out_degrees()
    in_deg = g.in_degrees()
    out_deg = g.out_degrees()
    new_in_deg1 = F.scatter_row(in_deg, F.tensor(new_nids), in_deg)
    new_out_deg1 = F.scatter_row(out_deg, F.tensor(new_nids), out_deg)
    assert np.all(F.asnumpy(new_in_deg == new_in_deg1))
    assert np.all(F.asnumpy(new_out_deg == new_out_deg1))

    orig_ids = F.asnumpy(new_g.ndata['orig_id'])

    # Old -> new: relabeling the successors of old node nid must give the
    # successors of new node new_nids[nid].
    for nid in range(g.number_of_nodes()):
        neighs = F.asnumpy(g.successors(nid))
        new_neighs2 = new_g.successors(new_nids[nid])
        _assert_same_neighbors(new_nids[neighs], F.asnumpy(new_neighs2))

    # New -> old: map neighbors of each new node back through orig_ids and
    # compare with the original graph, for both edge directions.
    for nid in range(new_g.number_of_nodes()):
        old_nid = orig_ids[nid]

        neighs = F.asnumpy(new_g.successors(nid))
        _assert_same_neighbors(orig_ids[neighs],
                               F.asnumpy(g.successors(old_nid)))

        neighs = F.asnumpy(new_g.predecessors(nid))
        _assert_same_neighbors(orig_ids[neighs],
                               F.asnumpy(g.predecessors(old_nid)))
コード例 #3
0
ファイル: test_new_kvstore.py プロジェクト: yw3388/dgl
def udf_push(target, name, id_tensor, data_tensor):
    """Write the element-wise square of ``data_tensor`` into ``target[name]``.

    The rows of ``target[name]`` selected by ``id_tensor`` are replaced with
    ``data_tensor * data_tensor`` via ``F.scatter_row``.  The returned tensor
    is stored back under ``target[name]``.  Nothing is returned.

    NOTE(review): the name suggests this is registered as a custom push
    handler with the KV store under test.  Confirm against the caller.
    """
    current = target[name]
    squared = data_tensor * data_tensor
    target[name] = F.scatter_row(current, id_tensor, squared)