def test_safe_pool() -> None:
    seq = list(range(10000))
    target = list(map(sqr, seq))  # sequential
    with mp.Pool(1) as pool:
        pool_type = type(pool)

    def file_obj():
        nonlocal call_count
        call_count += 1

    call_count = 0
    with flutes.safe_pool(0, closing=[file_obj]) as pool:
        check_iterator(pool.imap(sqr, seq), target)
        assert not isinstance(pool, pool_type)
    assert call_count == 1

    file_obj = NonCallableMagicMock()
    file_obj.mock_add_spec(["close"])
    with flutes.safe_pool(2, closing=[file_obj], suppress_exceptions=True) as pool:
        result = list(pool.imap(sqr, seq))
        raise ValueError  # should swallow exceptions
    assert isinstance(pool, pool_type)
    assert result == target
    file_obj.close.assert_called_once()

    with pytest.raises(KeyboardInterrupt):
        with flutes.safe_pool(2, closing=[file_obj], suppress_exceptions=True) as pool:
            raise KeyboardInterrupt

    manager = ContextManagerMock()
    with flutes.safe_pool(2, closing=[manager]) as pool:
        raise ValueError
    assert manager.state == 2
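# Minimal sketches (assumptions; the real helpers live next to the tests) of the utilities the
# pool tests rely on: `sqr` squares its argument with optional coefficients (so the `args`/`kwds`
# variants below also hold), `mul` multiplies two values the same way, and `check_iterator`
# compares an iterator against an expected list.
from typing import Iterable, List


def sqr(x: int, coef: int = 1, coef2: int = 1) -> int:
    return x * x * coef * coef2


def mul(a: int, b: int, coef: int = 1, coef2: int = 1) -> int:
    return a * b * coef * coef2


def check_iterator(it: Iterable[int], expected: List[int]) -> None:
    assert list(it) == expected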
def compute_scores():
    # Gather word counts and compute "competency" scores for each example, for use in curriculum learning.
    with flutes.safe_pool(args.n_procs, state_class=CountWordsState) as pool:
        for _ in pool.imap_unordered(
                CountWordsState.count_words,  # use only the source sentence
                map(lambda ex: ex.decompiled_code, tqdm(data, desc="Counting words")),
                chunksize=args.block_size):
            pass
        word_counter = Counter()
        for state in pool.get_states():
            word_counter.update(state.counter)
    total_words = sum(word_counter.values())
    word_scores = {w: -math.log(c / total_words) for w, c in word_counter.items()}
    with flutes.safe_pool(args.n_procs, state_class=ComputeScoreState,
                          init_args=(word_scores,)) as pool:
        scores = list(pool.imap(
            ComputeScoreState.compute_score,
            map(lambda ex: ex.decompiled_code, tqdm(data, desc="Computing scores")),
            chunksize=args.block_size))
    return scores
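# Minimal sketches (assumptions, not the project's actual classes) of the two pool states
# `compute_scores` relies on: `CountWordsState` accumulates per-process word counts, and
# `ComputeScoreState` maps a sentence to a competency score, taken here as the mean negative
# log-frequency of its words. Both are assumed to subclass `flutes.PoolState`, as in the
# library's stateful-pool examples.
import math
from collections import Counter
from typing import Counter as CounterT, Dict

import flutes


class CountWordsState(flutes.PoolState):
    def __init__(self) -> None:
        self.counter: CounterT[str] = Counter()

    def count_words(self, sentence: str) -> None:
        self.counter.update(sentence.split())


class ComputeScoreState(flutes.PoolState):
    def __init__(self, word_scores: Dict[str, float]) -> None:
        self.word_scores = word_scores

    def compute_score(self, sentence: str) -> float:
        words = sentence.split()
        return sum(self.word_scores.get(w, 0.0) for w in words) / max(len(words), 1)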
def test_gather() -> None:
    n = 10000
    intervals = list(range(0, n + 1, 1000))
    answer = set(range(n))
    for n_procs in [0, 2]:
        with flutes.safe_pool(n_procs) as pool:
            assert set(pool.gather(gather_fn, zip(intervals, intervals[1:]))) == answer
        with flutes.safe_pool(n_procs, state_class=PoolState2) as pool:
            assert set(pool.gather(PoolState2.gather_fn, zip(intervals, intervals[1:]))) == answer
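# A minimal sketch (an assumption; the real helper is defined alongside the tests) of the
# `gather_fn` used above. `pool.gather` is expected to flatten the iterables produced by each
# call, so returning the half-open range [start, end) for every interval reproduces
# `set(range(n))`.
from typing import List


def gather_fn(start: int, end: int) -> List[int]:
    return list(range(start, end))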
def main():
    args = Args()
    print(args.to_string())
    random.seed(args.random_seed)
    if args.pdb:
        flutes.register_ipython_excepthook()

    splits = {
        "train_extra": "train_extra/",
        "valid": "dev/",
        "test": "test/",
        "train": ".",
    }
    tranx_data_dir = Path(args.tranx_data_dir)
    text_data_dir = Path(args.text_data_dir)
    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    with flutes.safe_pool(args.n_procs, state_class=State) as pool:
        for key, path in splits.items():
            pool.broadcast(State.clear)
            files = [file for file in sorted(Path(tranx_data_dir / path).iterdir())
                     if file.name.startswith("data") and file.suffix == ".pkl"]
            for _ in tqdm(pool.imap_unordered(State.process_data, files), total=len(files)):
                pass
            canonical_tgt_map: Dict[str, str] = {}
            total_size = 0
            for state_map in pool.get_states():
                total_size += len(state_map.canonical_map)
                canonical_tgt_map.update(state_map.canonical_map)
            assert total_size == len(canonical_tgt_map)
            print(f"{key.capitalize()} set processed")

            in_path = text_data_dir / f"{key}.txt"
            out_path = output_dir / f"{key}.txt"
            not_found = set()
            with in_path.open("r") as fin, out_path.open("w") as fout:
                progress = tqdm(fin, total=total_size)
                for line in progress:
                    if not line:
                        continue
                    example = InputData.decode(line)
                    encoded_output = canonical_tgt_map.get(example.original_code, None)
                    if encoded_output is None:
                        not_found.add(example.original_code)
                        progress.set_postfix(not_found=len(not_found))
                    else:
                        del canonical_tgt_map[example.original_code]
                        fout.write(encoded_output)
                        fout.write("\n")
                print(f"{len(not_found)} not found, {len(canonical_tgt_map)} leftover")
                for encoded_output in canonical_tgt_map.values():
                    fout.write(encoded_output)
                    fout.write("\n")
            print(f"{key.capitalize()} set written to {out_path}")
def test_stateful_pool() -> None:
    large_dict = {str(i): i for i in range(100000)}
    seq = list(map(str, range(100000)))
    target = sum(map(lambda x: int(x) + 1, seq))  # sequential
    for n_procs in [0, 2]:
        with flutes.safe_pool(n_procs, state_class=PoolState, init_args=(large_dict,)) as pool:
            result = sum(pool.imap_unordered(PoolState.convert, seq, chunksize=1000))

            # See, if you had a type checker, you wouldn't be making these mistakes.
            with pytest.raises(ValueError, match="Bound methods of the pool state class"):
                _ = sum(pool.imap_unordered(PoolState({}).convert, seq, chunksize=1000))
            with pytest.raises(ValueError, match="Only unbound methods of the pool state class"):
                _ = sum(pool.imap_unordered(
                    PoolState2.generate, seq, chunksize=1000))  # type: ignore[arg-type]
        assert result == target
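# A minimal sketch (an assumption, not the test suite's actual definition) of the `PoolState`
# class used by the stateful-pool tests. State classes are assumed to subclass
# `flutes.PoolState`; `init_args` are forwarded to `__init__` in each worker process, and
# *unbound* methods such as `PoolState.convert` are what gets passed to the map functions.
from typing import Dict, Optional

import flutes


class PoolState(flutes.PoolState):
    def __init__(self, large_dict: Optional[Dict[str, int]]) -> None:
        self.large_dict = large_dict

    def convert(self, x: str) -> int:
        # Equivalent to `int(x) + 1` for the dictionary built in the test above.
        return self.large_dict[x] + 1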
def test_pool_methods() -> None:
    seq = list(range(10000))
    args = (2,)
    kwds = {"coef2": 3}
    target = [sqr(x, *args, **kwds) for x in seq]  # sequential
    for n_procs in [0, 2]:
        for state_class in [PoolState, None]:
            with flutes.safe_pool(n_procs, state_class=state_class, init_args=(None,)) as pool:
                assert pool.map(sqr, seq, args=args, kwds=kwds) == target
                check_iterator(pool.imap(sqr, seq, args=args, kwds=kwds), target)
                assert sorted(pool.imap_unordered(sqr, seq, args=args, kwds=kwds)) == target
                assert pool.starmap(mul, zip(seq, seq), args=args, kwds=kwds) == target
                assert pool.map_async(sqr, seq, args=args, kwds=kwds).get() == target
                assert pool.starmap_async(mul, zip(seq, seq), args=args, kwds=kwds).get() == target
                assert pool.apply(sqr, (10, 2), kwds=kwds) == 100 * 2 * 3
                assert pool.apply_async(sqr, (10, 2), kwds=kwds).get() == 100 * 2 * 3
def test_stateful_pool_get_state() -> None:
    for n_procs in [0, 2]:
        with flutes.safe_pool(n_procs, state_class=PoolState2) as pool:
            intervals = list(range(0, 100 + 1, 5))
            pool.starmap(PoolState2.generate, zip(intervals, intervals[1:]))
            states = pool.get_states()
        assert sorted(x for state in states for x in state.numbers) == list(range(100))
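# A minimal sketch (an assumption; the real class is defined alongside the tests) of
# `PoolState2` as the `get_states` tests use it: `generate` accumulates a per-worker list of
# numbers, `get_states()` returns one state instance per process, and iterating a state yields
# its numbers (which the dummy-args variant of this test relies on).
from typing import Iterator, List

import flutes


class PoolState2(flutes.PoolState):
    def __init__(self) -> None:
        self.numbers: List[int] = []

    def generate(self, start: int, end: int, *dummy) -> None:
        # Extra positional arguments (the "dummy args" in the other variant) are ignored.
        self.numbers.extend(range(start, end))

    def gather_fn(self, start: int, end: int) -> List[int]:
        return list(range(start, end))

    def __iter__(self) -> Iterator[int]:
        return iter(self.numbers)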
def test_ProgressBarManager() -> None:
    for proc in [0, 2]:
        # Test multiprocessing in `proc = 2`
        # Test coverage in `proc = 0`
        manager = flutes.ProgressBarManager()
        with flutes.safe_pool(proc, closing=[manager]) as pool:
            fn = functools.partial(progress_bar_fn, bar=manager.proxy)
            pool.map(fn, range(10))
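# A minimal sketch (hypothetical; the actual helper is defined next to the tests) of the
# `progress_bar_fn` worker used above. Each call receives the manager's proxy, opens a new bar
# through it, and reports progress, mirroring the `manager.proxy.new(...)` /
# `progress.update(...)` usage in the token-sampling script later in this file.
def progress_bar_fn(idx: int, bar) -> None:
    progress = bar.new(total=10)
    for _ in range(10):
        progress.update(1)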
def main(args):
    random.seed(args.seed)
    if (not args.output_dir.exists()) or args.overwrite:
        args.output_dir.mkdir(exist_ok=True, parents=True)
    else:
        print(f"Directory {args.output_dir} already exists.")
        sys.exit(0)

    tr_indices, vl_indices, ts_indices = make_splits_indices(100, args.split_ratio)
    with (args.output_dir / "file_map.txt").open("w") as f:
        print("train", file=f)
        for i in tr_indices:
            print(f"{i} ", file=f, end="")
        print(file=f)
        print("valid", file=f)
        for i in vl_indices:
            print(f"{i} ", file=f, end="")
        print(file=f)
        print("test", file=f)
        for i in ts_indices:
            print(f"{i} ", file=f, end="")
        print(file=f)

    # in/out file pairs
    files = ([((args.dump_dir / f"{i}.pkl"),
               (args.output_dir / f"train_{new_idx:02}.jsonl"))
              for new_idx, i in enumerate(tr_indices)] +
             [((args.dump_dir / f"{i}.pkl"),
               (args.output_dir / f"valid_{new_idx:02}.jsonl"))
              for new_idx, i in enumerate(vl_indices)] +
             [((args.dump_dir / f"{i}.pkl"),
               (args.output_dir / f"test_{new_idx:02}.jsonl"))
              for new_idx, i in enumerate(ts_indices)])
    # files = sorted(
    #     list((args.dump_dir).glob("*.pkl")), key=lambda x: int(x.with_suffix("").name)
    # )
    total = {}
    with flutes.safe_pool(processes=args.njobs, state_class=Processor) as pool:
        for idx, _ in enumerate(pool.imap_unordered(Processor.process_pkl, files, chunksize=1)):
            flutes.log(f"Processed {idx + 1} files")
        states = pool.get_states()
        for state in states:
            total.update(state.results)
    print(f"Train: {sum(v for k, v in total.items() if k.startswith('train'))}")
    print(f"Valid: {sum(v for k, v in total.items() if k.startswith('valid'))}")
    print(f"Test: {sum(v for k, v in total.items() if k.startswith('test'))}")
def test_stateful_pool_get_state() -> None:
    for n_procs in [0, 2]:
        with flutes.safe_pool(n_procs, state_class=PoolState2) as pool:
            intervals = list(range(0, 100 + 1, 5))
            pool.starmap(PoolState2.generate, zip(intervals, intervals[1:]), args=(1, 2))  # dummy args
            states = pool.get_states()
        assert sorted(itertools.chain.from_iterable(states)) == list(range(100))  # type: ignore[arg-type]
def test_safe_pool() -> None:
    seq = list(range(10000))
    target = list(map(sqr, seq))  # sequential
    with mp.Pool(1) as pool:
        pool_type = type(pool)

    file_obj = MagicMock()
    with flutes.safe_pool(0, closing=[file_obj]) as pool:
        assert type(pool) is not pool_type
        check_iterator(pool.imap(sqr, seq), target)
    file_obj.assert_called_once()

    file_obj = NonCallableMagicMock()
    file_obj.mock_add_spec(["close"])
    with flutes.safe_pool(2, closing=[file_obj]) as pool:
        assert type(pool) is pool_type
        check_iterator(pool.imap(sqr, seq), target)
        raise ValueError  # should swallow exceptions
    file_obj.close.assert_called_once()
def main() -> None:
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} [file]")
        sys.exit(1)
    path = sys.argv[1]

    with flutes.work_in_progress("Read file"):
        with open(path) as f:
            sentences = []
            for line in f:
                sentences.append(line)
                if len(sentences) >= 100000:
                    break

    with flutes.work_in_progress("Parallel"):
        with flutes.safe_pool(processes=4, state_class=WordCounter) as pool_stateful:
            for _ in pool_stateful.imap_unordered(WordCounter.count_words, sentences,
                                                  chunksize=1000):
                pass
            parallel_word_counter: CounterT[str] = Counter()
            with flutes.work_in_progress("Get states"):
                states = pool_stateful.get_states()
            for state in states:
                parallel_word_counter.update(state.word_cnt)

    with flutes.work_in_progress("Naive parallel"):
        naive_word_counter: CounterT[str] = Counter()
        data_chunks = flutes.chunk(1000, sentences)
        with flutes.safe_pool(processes=4) as pool:
            for counter in pool.imap_unordered(count_words, data_chunks):
                naive_word_counter.update(counter)

    with flutes.work_in_progress("Sequential"):
        seq_word_counter: CounterT[str] = Counter()
        for sent in sentences:
            seq_word_counter.update(word.lower() for word in sent.split())

    assert seq_word_counter == naive_word_counter == parallel_word_counter
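# Minimal sketches (assumptions matching the sequential check above, not necessarily the
# script's actual definitions) of the two counting helpers used in `main`: `WordCounter` is the
# pool state whose per-process counters are merged afterwards, and `count_words` is the plain
# chunk-level function used by the naive pool.
from collections import Counter
from typing import Counter as CounterT, List

import flutes


class WordCounter(flutes.PoolState):
    def __init__(self) -> None:
        self.word_cnt: CounterT[str] = Counter()

    def count_words(self, sentence: str) -> None:
        self.word_cnt.update(word.lower() for word in sentence.split())


def count_words(sentences: List[str]) -> CounterT[str]:
    counter: CounterT[str] = Counter()
    for sent in sentences:
        counter.update(word.lower() for word in sent.split())
    return counter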
def main(args) -> None:
    files = [(args.output_dir, args.input_dir / f"{i}.pkl") for i in range(0, 10000, 100)]
    total_map = {}
    with flutes.work_in_progress("Parallel"):
        with flutes.safe_pool(processes=args.njobs, state_class=Worker) as pool_stateful:
            for idx, _ in enumerate(pool_stateful.imap_unordered(Worker.merge, files, chunksize=1)):
                flutes.log(f"Processed {idx + 1} files")
def test_ProgressBarManager() -> None:
    for verbose in [False, True]:
        for proc in [0, 2]:
            # Test multiprocessing in `proc = 2`
            # Test coverage in `proc = 0`
            manager = flutes.ProgressBarManager(verbose=verbose)
            with flutes.safe_pool(proc, closing=[manager]) as pool:
                fn = functools.partial(progress_bar_fn, bar=manager.proxy)
                pool.map(fn, range(10))
                fn = functools.partial(file_progress_bar_fn, bar=manager.proxy)
                pool.map(fn, range(4))
            flutes.log(f"This should still show up: verbose={verbose}, proc={proc}",
                       force_console=True)
def main():
    args = Args()
    print(args.to_string())
    if args.pdb:
        flutes.register_ipython_excepthook()
    random.seed(args.random_seed)

    files = [file for file in Path(args.data_dir).iterdir() if file.suffix == ".pkl"]
    manager = flutes.ProgressBarManager()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with flutes.safe_pool(args.n_procs, state_class=PoolState, init_args=(manager.proxy,)) as pool:

        def _wrap_iter() -> Iterator[str]:
            progress = manager.proxy.new(total=len(files))
            for tokens in pool.imap_unordered(PoolState.collect_tokens, files):
                yield from tokens
                progress.update(1)
            ident_tokens = Counter()
            for id_counts in pool.get_states():
                ident_tokens.update(id_counts)
            progress = manager.proxy.new(total=sum(ident_tokens.values()))
            for token, count in ident_tokens.items():
                yield from [token] * count
                progress.update(count)

        sampled_tokens = sample(args.sample_size, _wrap_iter())

    with (output_dir / "tokens.txt").open("w") as f:
        random.shuffle(sampled_tokens)
        for token in sampled_tokens:
            f.write(token)
            f.write("\n")

    spm_train_args = {
        "input": output_dir / "tokens.txt",
        "model_prefix": output_dir / "vocab",
        "vocab_size": args.vocab_size,
        "split_by_whitespace": 0,  # false
        "remove_extra_whitespaces": 0,  # false
        # "input_sentence_size": 10 ** 8,
    }
    spm.SentencePieceTrainer.Train(" ".join(
        f"--{name}={str(value)}" for name, value in spm_train_args.items()))
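# A minimal sketch (an assumption; the script imports its own `sample`) of a helper matching the
# call `sample(args.sample_size, _wrap_iter())` above: single-pass reservoir sampling
# (Algorithm R), which draws a uniform sample of `size` items from a stream whose total length
# is unknown in advance.
import random
from typing import Iterable, List, TypeVar

T = TypeVar('T')


def sample(size: int, iterable: Iterable[T]) -> List[T]:
    reservoir: List[T] = []
    for idx, item in enumerate(iterable):
        if idx < size:
            reservoir.append(item)
        else:
            pos = random.randint(0, idx)
            if pos < size:
                reservoir[pos] = item
    return reservoir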
def test_pool_methods() -> None:
    seq = list(range(10000))
    target = list(map(sqr, seq))  # sequential
    for n_procs in [0, 2]:
        for state_class in [PoolState, None]:
            with flutes.safe_pool(n_procs, state_class=state_class, init_args=(None,)) as pool:
                check_iterator(pool.map(sqr, seq), target)
                check_iterator(pool.imap(sqr, seq), target)
                check_iterator(sorted(pool.imap_unordered(sqr, seq)), target)
                check_iterator(pool.starmap(operator.mul, zip(seq, seq)), target)
                check_iterator(pool.map_async(sqr, seq).get(), target)
                check_iterator(pool.starmap_async(operator.mul, zip(seq, seq)).get(), target)
                assert pool.apply(sqr, (10,)) == 100
                assert pool.apply_async(sqr, (10,)).get() == 100
def main() -> None:
    if args.n_procs == 0:
        # Only do this on the single-threaded case.
        flutes.register_ipython_excepthook()
    flutes.log(f"Running with {args.n_procs} worker processes", "warning")

    # Check for/create output directories
    make_directory(args.output_dir)

    # Use RAM-backed memory for tmp if available
    if os.path.exists('/dev/shm'):
        tempfile.tempdir = '/dev/shm'

    flutes.set_log_file(args.log_file)
    write_pseudo_registry()

    # Obtain a list of all binaries
    binaries = get_binary_mapping(args.binary_mapping_cache_file)
    flutes.log(f"{len(binaries)} binaries to process.")
    file_count = 0
    db = ghcc.BinaryDB()

    with flutes.safe_pool(args.n_procs, closing=[db]) as pool:
        decompile_fn: Callable[[BinaryInfo], DecompilationResult] = functools.partial(
            decompile, output_dir=args.output_dir, binary_dir=args.binaries_dir,
            timeout=args.timeout)
        for result in pool.imap_unordered(decompile_fn, iter_binaries(db, binaries)):
            file_count += 1
            if result is not None:
                db.add_binary(result.info["repo_owner"], result.info["repo_name"],
                              result.hash, result.status is DecompilationStatus.Success)
            if file_count % 100 == 0:
                flutes.log(f"Processed {file_count} binaries", force_console=True)
def main(args) -> None:
    Proc = LegacyFilter if args.legacy else Filter
    if args.target == "paper":
        processing = Proc.filter_ids_complete  # gathers pid, in/out cite pids
    elif args.target == "citation":
        processing = Proc.filter_ids_text  # gathers pids

    if args.valid_citations is not None and args.valid_citations.exists():
        with args.valid_citations.open("rb") as f:
            d = pickle.load(f)
        dict_valid_citations = {k: True for _, pids in d.items() for k in pids}
        del d
    else:
        dict_valid_citations = {}

    files = [(f, dict_valid_citations, args.min_cite, args.max_cite, args.seed)
             for f in list(args.input_dir.glob("*"))]

    with flutes.work_in_progress("Parallel"):
        total_results = defaultdict(list)
        with flutes.safe_pool(processes=args.njobs, state_class=Proc) as pool_stateful:
            for idx, _ in enumerate(pool_stateful.imap_unordered(processing, files, chunksize=10)):
                flutes.log(f"Processed {idx + 1} files")
            with flutes.work_in_progress("Get states"):
                states = pool_stateful.get_states()
            for state in states:
                total_results.update(state.results)

    with args.output_file.open("wb") as f:
        # Dict[batchnum, List[obj]]
        pickle.dump(total_results, f)
def main() -> None:
    if not ghcc.utils.verify_docker_image(verbose=True):
        exit(1)

    sys.setrecursionlimit(10000)
    args = Arguments()
    if args.pdb:
        flutes.register_ipython_excepthook()
    if args.n_procs == 0:
        globals()['match_functions'] = match_functions.__wrapped__

    if not args.verbose:
        flutes.set_logging_level("quiet", console=True, file=False)
    flutes.set_log_file(args.log_file)
    flutes.log("Running with arguments:\n" + args.to_string(), force_console=True)

    if os.path.exists(args.temp_dir):
        flutes.log(f"Removing contents of temporary folder '{args.temp_dir}'...",
                   "warning", force_console=True)
        ghcc.utils.run_docker_command(["rm", "-rf", "/usr/src/*"], user=0,
                                      directory_mapping={args.temp_dir: "/usr/src"})

    db = ghcc.MatchFuncDB()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    manager = flutes.ProgressBarManager(
        verbose=args.show_progress,
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}{postfix}]")
    with flutes.safe_pool(args.n_procs, closing=[db, manager]) as pool:
        iterator, stats = iter_repos(
            db, args.max_repos, skip_to=args.skip_to,
            cache_path=args.repo_binary_info_cache_path,
            force_reprocess=args.force_reprocess)
        match_fn: Callable[[RepoInfo], Result] = functools.partial(
            match_functions,
            archive_folder=args.archive_dir, temp_folder=args.temp_dir,
            decompile_folder=args.decompile_dir,
            use_fake_libc_headers=args.use_fake_libc_headers,
            preprocess_timeout=args.preprocess_timeout,
            progress_bar=manager.proxy)

        repo_count = stats.repo_count
        func_count = stats.func_count
        func_without_ast_count = stats.func_without_ast_count
        for result in pool.imap_unordered(match_fn, iterator):
            if result is None:
                # Exception occurred.
                if args.exit_on_exception:
                    flutes.log("Exception occurred, exiting because 'exit_on_exception' is True",
                               "warning")
                    break
                continue

            # Write the matched functions to disk.
            result: Result  # type: ignore
            repo_dir = output_dir / result.repo_owner / result.repo_name
            repo_dir.mkdir(parents=True, exist_ok=True)
            with (repo_dir / "matched_funcs.jsonl").open("w") as f:
                for matched_func in result.matched_functions:
                    f.write(json.dumps(matched_func._asdict(), separators=(',', ':')) + "\n")
            for sha, code in result.preprocessed_original_code.items():
                with (repo_dir / f"{sha}.c").open("w") as f:
                    pos = code.rfind(ghcc.parse.FAKE_LIBC_END_LINE)
                    if pos != -1:
                        code = code[(pos + len(ghcc.parse.FAKE_LIBC_END_LINE)):]
                    f.write(code)

            if args.write_db:
                db.add_repo(result.repo_owner, result.repo_name,
                            files_found=result.files_found,
                            funcs_found=result.functions_found,
                            funcs_matched=len(result.matched_functions),
                            funcs_matched_without_ast=result.funcs_without_asts)

            repo_count += 1
            func_count += len(result.matched_functions)
            func_without_ast_count += result.funcs_without_asts
            if repo_count % 100 == 0:
                flutes.log(f"Processed {repo_count} repositories, {func_count} functions matched "
                           f"({func_without_ast_count} w/o AST)", force_console=True)
def main() -> None:
    if not ghcc.utils.verify_docker_image(verbose=True):
        exit(1)

    args = Arguments()
    if args.n_procs == 0:
        # Only do this on the single-threaded case.
        flutes.register_ipython_excepthook()
    flutes.set_log_file(args.log_file)
    flutes.set_logging_level(args.logging_level, console=True, file=False)
    flutes.log("Running with arguments:\n" + args.to_string(), force_console=True)

    if os.path.exists(args.clone_folder):
        flutes.log(f"Removing contents of clone folder '{args.clone_folder}'...",
                   "warning", force_console=True)
        ghcc.utils.run_docker_command(["rm", "-rf", "/usr/src/*"], user=0,
                                      directory_mapping={args.clone_folder: "/usr/src"})

    flutes.log("Crawling starts...", "warning", force_console=True)
    db = ghcc.RepoDB()
    libraries: Set[str] = set()
    if args.record_libraries is not None and os.path.exists(args.record_libraries):
        with open(args.record_libraries, "r") as f:
            libraries = set(f.read().split())

    def flush_libraries():
        if args.record_libraries is not None:
            with open(args.record_libraries, "w") as f:
                f.write("\n".join(libraries))

    with flutes.safe_pool(args.n_procs, closing=[db, flush_libraries]) as pool:
        iterator = iter_repos(db, args.repo_list_file, args.max_repos)
        pipeline_fn: Callable[[RepoInfo], Optional[PipelineResult]] = functools.partial(
            clone_and_compile,
            clone_folder=args.clone_folder, binary_folder=args.binary_folder,
            archive_folder=args.archive_folder,
            recursive_clone=args.recursive_clone,
            clone_timeout=args.clone_timeout, compile_timeout=args.compile_timeout,
            force_reclone=args.force_reclone, force_recompile=args.force_recompile,
            docker_batch_compile=args.docker_batch_compile,
            max_archive_size=args.max_archive_size,
            compression_type=args.compression_type,
            record_libraries=(args.record_libraries is not None),
            record_metainfo=args.record_metainfo,
            gcc_override_flags=args.gcc_override_flags)
        repo_count = 0
        meta_info = MetaInfo()
        for result in pool.imap_unordered(pipeline_fn, iterator):
            repo_count += 1
            if repo_count % 100 == 0:
                flutes.log(f"Processed {repo_count} repositories", force_console=True)
            if result is None:
                continue
            repo_owner, repo_name = result.repo_info.repo_owner, result.repo_info.repo_name
            if args.write_db:
                if result.clone_success is not None or result.repo_info.db_result is None:
                    # There's probably an inconsistency somewhere if we didn't clone while `db_result` is None.
                    # To prevent more errors, just add it to the DB.
                    repo_size = result.repo_size or -1  # a value of zero is probably also wrong
                    clone_success = result.clone_success if result.clone_success is not None else True
                    db.add_repo(repo_owner, repo_name, clone_success, repo_size=repo_size)
                    flutes.log(f"Added {repo_owner}/{repo_name} to DB")
                if result.makefiles is not None:
                    update_result = db.update_makefile(repo_owner, repo_name, result.makefiles,
                                                       ignore_length_mismatch=True)
                    if not update_result:
                        flutes.log(f"Makefiles of {repo_owner}/{repo_name} not saved to DB due to "
                                   f"Unicode encoding errors", "error")
            if result.libraries is not None:
                libraries.update(result.libraries)
                if repo_count % 10 == 0:  # flush every 10 repos
                    flush_libraries()
            if args.record_metainfo:
                meta_info.add_repo(result)
                if repo_count % 100 == 0:
                    flutes.log(repr(meta_info), force_console=True)

    flutes.log(repr(meta_info), force_console=True)
def main(args) -> None:
    pids_with_text = set()
    if args.valid_pids and args.valid_pids.exists():
        pids_with_text = set()
        with args.valid_pids.open("rb") as f:
            data = pickle.load(f)
        buf = set()
        if args.mode == "citation":
            for k, d in tqdm(enumerate(data.values()), ncols=88, ascii=True):
                buf = buf.union(set([pid for i in d for pid in i[1] + i[2]]))
                if k % 500 == 0:
                    pids_with_text = pids_with_text.union(buf)
                    buf = set()
        elif args.mode == "paper":
            for k, d in tqdm(data.items(), ncols=88, ascii=True):
                buf = buf.union(set([i[0] for i in d]))
                if k % 500 == 0:
                    pids_with_text = pids_with_text.union(buf)
                    buf = set()
        # remaining one
        pids_with_text = pids_with_text.union(buf)

    flutes.log(f"# of valid pids to consider: {len(pids_with_text)}")

    if args.legacy:
        # glob takes more time than this?
        files = ((args.input_dir / f"{i}.jsonl.gz", pids_with_text) for i in range(10000))
        Proc = LegacyFilter
    else:
        files = ((args.input_dir / f"pdf_parses_{i}.jsonl.gz", pids_with_text) for i in range(100))
        Proc = Filter

    with flutes.work_in_progress("Parallel"):
        total_map = {}
        with flutes.safe_pool(processes=args.njobs, state_class=Proc) as pool_stateful:
            for idx, _ in enumerate(pool_stateful.imap_unordered(Proc.make_map, files, chunksize=10)):
                flutes.log(f"Processed {idx + 1} files")
            with flutes.work_in_progress("Get states"):
                states = pool_stateful.get_states()
            for state in states:
                # TODO: Incorporate incite number
                total_map.update(state.results)

    flutes.log(f"Total map size: {len(total_map)}")
    with args.output.open("w") as f:
        for k, v in total_map.items():
            print(k, v[0], v[1], sep="\t", file=f)
    flutes.log(f"Dumped to {args.output}")
def main():
    # flutes.register_ipython_excepthook()
    sys.setrecursionlimit(50000)
    args = Arguments()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    flutes.log("Dataset creation start")

    db = ghcc.MatchFuncDB()
    original_code_set: Set[str] = set()
    n_duplicate = 0
    n_examples = 0
    manager = mp.Manager()
    example_queue: 'mp.Queue[QueueElem]' = manager.Queue(args.queue_size)
    with flutes.safe_pool(args.n_procs, closing=[db]) as pool:
        repos = [RepoInfo(entry['repo_owner'], entry['repo_name'])
                 for entry in db.collection.find() if entry['funcs_matched'] > 0]
        if args.max_repos is not None:
            repos = repos[:args.max_repos]
        process_fn: Callable[[RepoInfo], None] = functools.partial(
            process, data_dir=args.input_dir, queue=example_queue)
        pool.map_async(process_fn, repos, error_callback=flutes.log_exception)
        end_signals = 0
        progress = tqdm.tqdm(total=len(repos))
        file_cnt = 0
        text_data = []

        def save_file():
            nonlocal file_cnt, text_data
            # Save text & AST separately
            with (output_dir / f"data_{file_cnt:03d}.pkl").open("wb") as f:
                pickle.dump(text_data, f, protocol=PICKLE_PROTOCOL)
            progress.write(f"Saved part {file_cnt:03d}")
            text_data = []
            file_cnt += 1

        while end_signals < len(repos):
            elem = example_queue.get()
            if elem == END_SIGNATURE:
                progress.update(1)
                end_signals += 1
                continue

            ex = pickle.loads(elem)
            original_code = ex[1]
            if original_code not in original_code_set:
                original_code_set.add(original_code)
                text_data.append(ex)  # (decompiled, orig, var_names, repo, sha)
                n_examples += 1
            else:
                n_duplicate += 1
            if (n_examples + n_duplicate) % 100 == 0:
                progress.set_postfix({"duplicate": n_duplicate, "examples": n_examples},
                                     refresh=False)
                progress.refresh()
            if len(text_data) >= args.block_size:
                save_file()

        if len(text_data) > 0:
            save_file()
def main():
    args = Args()
    print(args.to_string())
    random.seed(args.random_seed)
    if args.pdb:
        flutes.register_ipython_excepthook()
    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    data = flutes.LazyList(read_data(args.data_dir))

    @flutes.cache(output_dir / "scores.pkl", name="competency scores")
    def compute_scores():
        # Gather word counts and compute "competency" scores for each example, for use in curriculum learning.
        with flutes.safe_pool(args.n_procs, state_class=CountWordsState) as pool:
            for _ in pool.imap_unordered(
                    CountWordsState.count_words,  # use only the source sentence
                    map(lambda ex: ex.decompiled_code, tqdm(data, desc="Counting words")),
                    chunksize=args.block_size):
                pass
            word_counter = Counter()
            for state in pool.get_states():
                word_counter.update(state.counter)
        total_words = sum(word_counter.values())
        word_scores = {w: -math.log(c / total_words) for w, c in word_counter.items()}
        with flutes.safe_pool(args.n_procs, state_class=ComputeScoreState,
                              init_args=(word_scores,)) as pool:
            scores = list(pool.imap(
                ComputeScoreState.compute_score,
                map(lambda ex: ex.decompiled_code, tqdm(data, desc="Computing scores")),
                chunksize=args.block_size))
        return scores

    scores = compute_scores()

    # Generate data splits.
    test_size = args.test_split_size or int(len(data) * args.test_split_portion)

    # Dev/Test set: Repositories excluded in training set.
    data_by_repo = defaultdict(list)
    for idx, ex in enumerate(data):
        data_by_repo[ex.repo].append(idx)
    repo_names = list(data_by_repo.keys())

    def create_excluded_split(target_size: int, max_repos: int, extra_train_portion: float,
                              min_repo_size: int = 0) -> Tuple[List[str], List[int], List[int]]:
        # ([name], [index])
        filtered_repos = repo_names
        if min_repo_size > 0:
            filtered_repos = [repo for repo in filtered_repos
                              if len(data_by_repo[repo]) >= min_repo_size]
        while True:
            repo_count = random.randint(1, min(len(filtered_repos), max_repos))
            chosen_repos = random.choices(filtered_repos, k=repo_count)
            sample_size = sum(len(data_by_repo[name]) for name in chosen_repos)
            if 0.8 * target_size <= sample_size <= 1.1 * target_size:
                # Keep sampling until we get something with appropriate size.
                break
        extra_train_indices = []
        split_indices = []
        for name in chosen_repos:
            indices = data_by_repo[name].copy()
            random.shuffle(indices)
            split_size = int(len(indices) * extra_train_portion)
            extra_train_indices += indices[:split_size]
            split_indices += indices[split_size:]
        return chosen_repos, split_indices, extra_train_indices

    dev_repos, dev_split, extra_train_dev_split = create_excluded_split(
        test_size, args.max_test_repos, args.extra_train_portion)
    for repo_name in dev_repos:
        del data_by_repo[repo_name]
    test_repos, test_split, extra_train_test_split = create_excluded_split(
        test_size, args.max_test_repos, args.extra_train_portion)
    excluded_indices = set(dev_split + extra_train_dev_split +
                           test_split + extra_train_test_split)

    # Training set: all the remaining stuff.
    train_split = [idx for idx in range(len(data)) if idx not in excluded_indices]
    train_split.sort(key=lambda i: scores[i])  # sort training indices according to competency score
    extra_train_split = extra_train_dev_split + extra_train_test_split

    splits = {
        "train": train_split,
        "valid": dev_split,
        "test": test_split,
        "train_extra": extra_train_split,
    }
    with (output_dir / "split_indices.pkl").open("wb") as f:
        pickle.dump(splits, f)

    def write_files(folder_path: str, sentence_fn: Callable[[int], Tuple[str, str]]):
        for key, indices in splits.items():
            with open(os.path.join(folder_path, f"{key}.txt"), "w") as f:
                for idx in tqdm(indices, desc=f"Writing {key} set", leave=False):
                    src, tgt = sentence_fn(idx)
                    ex = src, tgt, data[idx].var_names, scores[idx], data[idx].repo, data[idx].sha
                    output = OutputData.encode(*ex)
                    # assert tuple(OutputData.decode(output)) == ex
                    f.write(output)
                    f.write("\n")
            print(f"{key.capitalize()} set written")

    write_files(output_dir, lambda idx: data[idx][:2])

    for key, names in [("valid", dev_repos), ("test", test_repos)]:
        with (output_dir / f"{key}_repos.txt").open("w") as f:
            f.write("\n".join(names))

    if not (output_dir / "vocab.model").exists():
        # Write out training text and train SentencePiece model.
        train_text_path = output_dir / "train_text.txt"
        with train_text_path.open("w") as f:
            for idx in tqdm(train_split, desc="Writing training text"):
                src_tokens = data[idx].decompiled_code.split(TOKEN_SEP)
                new_src_tokens = []
                for token in src_tokens:
                    if token in data[idx].var_names:
                        var1, var2 = data[idx].var_names[token]
                        new_src_tokens += [var1, var2]
                    else:
                        new_src_tokens.append(token)
                f.write(" ".join(new_src_tokens) + "\n")
                f.write(data[idx].original_code.replace(TOKEN_SEP, " ") + "\n")
        spm_train_args = {
            "input": train_text_path,
            "model_prefix": output_dir / "vocab",
            "vocab_size": args.vocab_size,
        }
        spm.SentencePieceTrainer.Train(" ".join(
            f"--{name}={str(value)}" for name, value in spm_train_args.items()))

    if args.encode_spm:
        # Encode all sentences with the trained SP model.
        with flutes.safe_pool(args.n_procs, state_class=EncodeSPMState,
                              init_args=(output_dir / "vocab.model",)) as pool:
            processed_code = list(pool.imap(
                EncodeSPMState.encode_spm,
                map(lambda ex: (ex.decompiled_code, ex.original_code),
                    tqdm(data, desc="Encoding with SPM")),
                chunksize=args.block_size))
        write_files(output_dir / "tokenized", lambda idx: processed_code[idx])