Exemple #1
0
 def __sample__(self):
     """Sample a set of node IDs by random walks over ``self.train_g``.

     ``self.num_roots`` root node IDs are drawn uniformly (with replacement)
     from ``[0, self.train_g.num_nodes())``; a random walk of
     ``self.length`` steps is started from each root, and the de-duplicated
     node IDs from the packed traces are returned.

     :return: NumPy array of the unique node IDs produced by the walks
     """
     graph = self.train_g
     root_ids = th.randint(0, graph.num_nodes(), (self.num_roots, ))
     walks, walk_types = random_walk(graph,
                                     nodes=root_ids,
                                     length=self.length)
     # Only the first element of the packed result (the node IDs) is used.
     # NOTE(review): per the DGL docs pack_traces also strips the padding of
     # walks that ended early -- confirm against the installed version.
     packed_nodes, _, _, _ = pack_traces(walks, walk_types)
     return packed_nodes.unique().numpy()
Exemple #2
0
def metapath_random_walk(g, metapaths, num_walks, walk_length, output_file):
    """Run metapath-based random walks and write the traces to a file.

    For every node type in ``metapaths``, all nodes of that type are visited
    in batches; each node starts ``num_walks`` walks that follow the node
    type's metapath repeated ``walk_length`` times. Each trace is converted
    to text with ``trace2name`` and written as one line of ``output_file``.

    :param g: DGLGraph, the heterogeneous graph
    :param metapaths: Dict[str, List[str]], mapping from node type to metapath
    :param num_walks: int, number of walks started from each node
    :param walk_length: int, number of times the metapath is repeated
    :param output_file: str, name of the output file
    :return: None
    """
    # The context manager guarantees the file is flushed and closed even if a
    # walk or a write raises part-way through.
    with open(output_file, 'w') as f:
        for ntype, metapath in metapaths.items():
            print(ntype)
            # Loop-invariant for this node type: the full path of one walk.
            full_metapath = metapath * walk_length
            loader = DataLoader(torch.arange(g.num_nodes(ntype)), batch_size=200)
            for b in tqdm(loader, ncols=80):
                # Repeat every start node so it begins num_walks walks.
                nodes = torch.repeat_interleave(b, num_walks)
                traces, types = random_walk(g, nodes, metapath=full_metapath)
                f.writelines(
                    [trace2name(g, trace, types) + '\n' for trace in traces])
def main():
    """Generate conf-author random walks for metapath2vec and save them.

    Command-line arguments:
      --num-walks    number of walks started from each conf node (default 1000)
      --walk-length  number of times the 'ca' -> 'ac' metapath is repeated
                     (default 100)
      output_file    file the walk traces are written to, one per line

    A bipartite conf/author heterograph is built from the result of
    ``metapath_adj(g, ['cp', 'pa'])``; the matrix values are stored as the
    edge feature 'p' and passed to ``random_walk`` via ``prob='p'``.
    """
    parser = argparse.ArgumentParser(
        description='Metapath-based Random Walk for metapath2vec')
    parser.add_argument('--num-walks',
                        type=int,
                        default=1000,
                        help='number of walks for each node')
    parser.add_argument('--walk-length',
                        type=int,
                        default=100,
                        help='times to repeat metapath')
    parser.add_argument('output_file', help='output filename')
    args = parser.parse_args()

    data = AMinerCSDataset()
    g = data[0]

    # conf x author matrix obtained by following 'cp' then 'pa'.
    ca = metapath_adj(g, ['cp', 'pa'])
    ca_c, ca_a = ca.nonzero()
    cag = dgl.heterograph(
        {
            ('conf', 'ca', 'author'): (ca_c, ca_a),
            ('author', 'ac', 'conf'): (ca_a, ca_c)
        }, {
            'conf': ca.shape[0],
            'author': ca.shape[1]
        })
    # Both edge directions share the same weights.
    # NOTE(review): assumes the order of ca.data matches the order of the
    # pairs returned by ca.nonzero() (no explicitly stored zeros) -- confirm.
    cag.edges['ca'].data['p'] = cag.edges['ac'].data['p'] = torch.from_numpy(
        ca.data).float()

    metapath = ['ca', 'ac']  # metapath = CAC, metapath*2 = CACAC
    # Loop-invariant: the full path followed by every walk.
    full_metapath = metapath * args.walk_length
    # The context manager closes the file even if a walk or a write fails in
    # the middle of the (long) loop.
    with open(args.output_file, 'w') as f:
        for cid in trange(cag.num_nodes('conf'), ncols=80):
            traces, _ = random_walk(cag, [cid] * args.num_walks,
                                    metapath=full_metapath,
                                    prob='p')
            f.writelines([
                trace2name(data.author_names, data.conf_names, t) + '\n'
                for t in traces
            ])
 def sample(g, nodes, length, queue: mp.Queue):
     """Run random walks from ``nodes`` and put the traces on ``queue``.

     :param g: graph passed to ``dgl.sampling.random_walk``
     :param nodes: start nodes of the walks
     :param length: value passed as ``length`` to ``random_walk``
     :param queue: queue that receives the traces as a nested Python list
     :return: None; the result is delivered through ``queue`` only
     """
     # Imported inside the function, as in the original.
     # NOTE(review): presumably so the import happens in the worker process
     # that runs this function -- confirm against the caller.
     from dgl.sampling import random_walk
     walk_result = random_walk(g, nodes, length=length)
     # Only the first element of the result (the traces) is forwarded.
     trace_rows = walk_result[0].numpy().tolist()
     queue.put(trace_rows)