Esempio n. 1
0
def doc_subgraph(G, doc_ids):
    """Build a graph from the one-hop neighborhood of the given 'doc' nodes.

    A one-layer ``MultiLayerFullNeighborSampler`` is run on ``G.reverse()``,
    seeded with ``doc_ids`` as nodes of type ``'doc'``.  The single sampled
    block's underlying ``_graph`` is re-wrapped as a ``DGLHeteroGraph`` with
    node type names ``['_', 'word', 'doc', '_']`` and the block's edge types,
    and that graph is reversed once more before being returned.

    Parameters
    ----------
    G
        Graph exposing ``reverse()``; presumably a DGL heterograph with
        'word' and 'doc' node types -- confirm against callers.
    doc_ids
        Seed 'doc' node IDs, in any form ``torch.as_tensor`` accepts.

    Returns
    -------
    The re-wrapped, reversed graph whose 'word' nodes carry the block's
    ``'_ID'`` node data.
    """
    one_hop_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    reversed_graph = G.reverse()
    seeds = {'doc': torch.as_tensor(doc_ids)}

    # The sampler yields a 3-tuple; only the block list is needed, and with a
    # single layer it must contain exactly one block.
    _, _, sampled_blocks = one_hop_sampler.sample(reversed_graph, seeds)
    (block,) = sampled_blocks

    ntype_names = ['_', 'word', 'doc', '_']
    rewrapped = dgl.DGLHeteroGraph(block._graph, ntype_names, block.etypes)
    subgraph = rewrapped.reverse()

    # Carry the block's '_ID' data for 'word' nodes over to the result.
    subgraph.nodes['word'].data['_ID'] = block.nodes['word'].data['_ID']
    return subgraph
Esempio n. 2
0
def create_test_heterograph():
    """Return the small heterograph fixture used by the tests.

    This is the test heterograph from the docstring, plus a
    user -- wishes -- game relation.  The graph is constructed by passing a
    ``(metagraph, edges_per_relation)`` pair to ``dgl.DGLHeteroGraph``.
    """
    # Metagraph: one (source type, destination type, relation name) per entry.
    metagraph_edges = [
        ('user', 'user', 'follows'),
        ('user', 'game', 'plays'),
        ('user', 'game', 'wishes'),
        ('developer', 'game', 'develops'),
    ]
    metagraph = nx.MultiDiGraph(metagraph_edges)

    # The 'plays' relation is given as a sparse COO matrix rather than an
    # edge list; the other relations use plain (src, dst) pairs.
    plays_rows = [0, 1, 2, 1]
    plays_cols = [0, 0, 1, 1]
    plays_adj = ssp.coo_matrix(([1, 1, 1, 1], (plays_rows, plays_cols)))

    relation_edges = {
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): plays_adj,
        ('user', 'wishes', 'game'): [(0, 1), (2, 0)],
        ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
    }

    return dgl.DGLHeteroGraph((metagraph, relation_edges))
Esempio n. 3
0
def test_pickling_heterograph_index_compatibility():
    """Check that a stored pickled heterograph index can still be loaded.

    A reference heterograph is built from four relations (follows, plays,
    wishes, develops), each created from a different input form: edge list,
    scipy sparse matrix, and networkx bipartite graph.  The object unpickled
    from ``tests/compute/hetero_pickle_old.pkl`` (presumably a graph index
    written by an older DGL release -- inferred from the file name) is then
    wrapped in a ``DGLHeteroGraph`` using the reference graph's node and edge
    type names, and the two graphs are asserted to be identical.
    """
    plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))

    # 'wishes' relation expressed as a networkx bipartite graph: users on
    # side 0, games on side 1, with explicit edge ids.
    wishes_nx = nx.DiGraph()
    wishes_nx.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
    wishes_nx.add_nodes_from(['g0', 'g1'], bipartite=1)
    wishes_nx.add_edge('u0', 'g1', id=0)
    wishes_nx.add_edge('u2', 'g0', id=1)

    follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
    plays_g = dgl.bipartite(plays_spmat, 'user', 'plays', 'game')
    wishes_g = dgl.bipartite(wishes_nx, 'user', 'wishes', 'game')
    develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
    g = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])

    # NOTE: the path is relative to the current working directory, so the
    # test has to be run from the directory that contains ``tests/``.
    # The ``with`` block closes the file on exit; no explicit close() needed.
    with open("tests/compute/hetero_pickle_old.pkl", "rb") as f:
        gi = pickle.load(f)

    new_g = dgl.DGLHeteroGraph(gi, g.ntypes, g.etypes)
    _assert_is_identical_hetero(g, new_g)