def __sample__(self):
    """Return the de-duplicated node IDs touched by random walks from random roots.

    ``self.num_roots`` root IDs are drawn independently (with replacement) from
    ``[0, self.train_g.num_nodes())``.  A walk of ``self.length`` steps is run
    from each root, the traces are packed with ``pack_traces``, and the unique
    node IDs from the packed traces are returned as a NumPy array.
    """
    graph = self.train_g
    roots = th.randint(0, graph.num_nodes(), (self.num_roots, ))
    walk_traces, walk_types = random_walk(graph, nodes=roots, length=self.length)
    # Only the first output of pack_traces (the flattened node IDs) is needed.
    visited, _, _, _ = pack_traces(walk_traces, walk_types)
    return visited.unique().numpy()
def metapath_random_walk(g, metapaths, num_walks, walk_length, output_file):
    """Metapath-based random walk.

    For every node type in ``metapaths``, start ``num_walks`` walks from each
    node of that type.  Each walk follows that type's metapath repeated
    ``walk_length`` times.  One line per walk (rendered by ``trace2name``) is
    written to ``output_file``.

    :param g: DGLGraph heterogeneous graph
    :param metapaths: Dict[str, List[str]] mapping from node type to its metapath
    :param num_walks: int number of walks started from each node
    :param walk_length: int number of times the metapath is repeated
    :param output_file: str output filename
    :return: None
    """
    # `with` closes the file even if a walk or trace2name raises part-way through;
    # an explicit encoding keeps node names from depending on the platform default.
    with open(output_file, 'w', encoding='utf-8') as f:
        for ntype, metapath in metapaths.items():
            print(ntype)
            loader = DataLoader(torch.arange(g.num_nodes(ntype)), batch_size=200)
            for b in tqdm(loader, ncols=80):
                # Each node ID in the batch is repeated num_walks times, giving num_walks walks per node.
                nodes = torch.repeat_interleave(b, num_walks)
                traces, types = random_walk(g, nodes, metapath=metapath * walk_length)
                f.writelines(trace2name(g, trace, types) + '\n' for trace in traces)
def main():
    """Write conf-author-conf metapath random walks over AMiner for metapath2vec.

    Parses the CLI arguments and loads ``AMinerCSDataset``.  It then builds a
    conf<->author heterograph from ``metapath_adj(g, ['cp', 'pa'])``, using
    that matrix's values as the edge feature ``'p'``.  ``random_walk`` uses
    ``'p'`` as transition weights.  From every conf node it starts
    ``--num-walks`` walks following ``['ca', 'ac'] * --walk-length``, and
    writes one line per walk (via ``trace2name``) to ``output_file``.
    """
    parser = argparse.ArgumentParser(
        description='Metapath-based Random Walk for metapath2vec')
    parser.add_argument('--num-walks', type=int, default=1000, help='number of walks for each node')
    parser.add_argument('--walk-length', type=int, default=100, help='times to repeat metapath')
    parser.add_argument('output_file', help='output filename')
    args = parser.parse_args()

    data = AMinerCSDataset()
    g = data[0]
    # Adjacency along the cp -> pa metapath (rows: conf, cols: author).
    ca = metapath_adj(g, ['cp', 'pa'])
    ca_c, ca_a = ca.nonzero()
    cag = dgl.heterograph(
        {
            ('conf', 'ca', 'author'): (ca_c, ca_a),
            ('author', 'ac', 'conf'): (ca_a, ca_c)
        },
        {
            'conf': ca.shape[0],
            'author': ca.shape[1]
        })
    # NOTE(review): assumes `ca` is a scipy sparse matrix whose `.data` is in the
    # same order as `ca.nonzero()` (true when it stores no explicit zeros) — confirm.
    cag.edges['ca'].data['p'] = cag.edges['ac'].data['p'] = torch.from_numpy(
        ca.data).float()
    metapath = ['ca', 'ac']  # metapath = CAC, metapath*2 = CACAC

    # `with` closes the file even if a walk raises; an explicit encoding keeps
    # non-ASCII author/conf names from depending on the platform default.
    with open(args.output_file, 'w', encoding='utf-8') as f:
        for cid in trange(cag.num_nodes('conf'), ncols=80):
            traces, _ = random_walk(cag, [cid] * args.num_walks,
                                    metapath=metapath * args.walk_length, prob='p')
            f.writelines(
                trace2name(data.author_names, data.conf_names, t) + '\n'
                for t in traces
            )
def sample(g, nodes, length, queue: mp.Queue):
    """Run ``length``-step random walks from ``nodes`` on ``g`` and enqueue the traces.

    The traces tensor (first element of ``random_walk``'s result) is converted
    to a nested Python list and put on ``queue``.  Nothing is returned.  The
    queue argument suggests this is a multiprocessing worker target —
    presumably; confirm against the caller.
    """
    # Imported locally, as in the original; presumably so each worker process imports DGL itself.
    from dgl.sampling import random_walk as walk
    walk_result = walk(g, nodes, length=length)
    trace_rows = walk_result[0].numpy().tolist()
    queue.put(trace_rows)
    return None