def main():
    args = Args()
    flutes.register_ipython_excepthook()
    with open(args.binary_list) as f:
        binaries = [line.split() for line in f if line]
    output_dir = Path(args.output_dir)
    file_descriptions = [
        FileDescription(name="binaries", folder="binaries",
                        pattern=args.binary_path_pattern, filename="{sha}"),
        FileDescription(name="decompiled output", folder="decompiled",
                        pattern=args.decompiled_path_pattern, filename="{sha}.jsonl"),
        # FileDescription(name="matched functions", folder="matched_funcs",
        #                 pattern=args.matched_func_path_pattern, filename="{sha}.jsonl"),
        FileDescription(name="preprocessed code", folder="code",
                        pattern=args.preprocessed_code_path_pattern, filename="{sha}.c"),
    ]
    for desc in file_descriptions:
        output_folder = output_dir / desc.folder
        output_folder.mkdir(exist_ok=True, parents=True)
        for repo, sha in tqdm(binaries, desc=f"Copying {desc.name}"):
            shutil.copy(desc.pattern.format(repo=repo, sha=sha),
                        output_folder / desc.filename.format(sha=sha))
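# `FileDescription` is assumed to be defined elsewhere in the script above. A plausible
# (hypothetical) definition matching the fields it is constructed with would be:
from typing import NamedTuple


class FileDescription(NamedTuple):
    name: str      # human-readable name, used in progress-bar descriptions
    folder: str    # sub-folder of the output directory to copy into
    pattern: str   # source path pattern with `{repo}` / `{sha}` placeholders
    filename: str  # destination file name pattern with a `{sha}` placeholder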
def main():
    args = Args()
    print(args.to_string())
    random.seed(args.random_seed)
    if args.pdb:
        flutes.register_ipython_excepthook()
    splits = {
        "train_extra": "train_extra/",
        "valid": "dev/",
        "test": "test/",
        "train": ".",
    }
    tranx_data_dir = Path(args.tranx_data_dir)
    text_data_dir = Path(args.text_data_dir)
    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    with flutes.safe_pool(args.n_procs, state_class=State) as pool:
        for key, path in splits.items():
            pool.broadcast(State.clear)
            files = [file for file in sorted(Path(tranx_data_dir / path).iterdir())
                     if file.name.startswith("data") and file.suffix == ".pkl"]
            for _ in tqdm(pool.imap_unordered(State.process_data, files), total=len(files)):
                pass
            canonical_tgt_map: Dict[str, str] = {}
            total_size = 0
            for state_map in pool.get_states():
                total_size += len(state_map.canonical_map)
                canonical_tgt_map.update(state_map.canonical_map)
            assert total_size == len(canonical_tgt_map)
            print(f"{key.capitalize()} set processed")

            in_path = text_data_dir / f"{key}.txt"
            out_path = output_dir / f"{key}.txt"
            not_found = set()
            with in_path.open("r") as fin, out_path.open("w") as fout:
                progress = tqdm(fin, total=total_size)
                for line in progress:
                    if not line:
                        continue
                    example = InputData.decode(line)
                    encoded_output = canonical_tgt_map.get(example.original_code, None)
                    if encoded_output is None:
                        not_found.add(example.original_code)
                        progress.set_postfix(not_found=len(not_found))
                    else:
                        del canonical_tgt_map[example.original_code]
                        fout.write(encoded_output)
                        fout.write("\n")
                print(f"{len(not_found)} not found, {len(canonical_tgt_map)} leftover")
                for encoded_output in canonical_tgt_map.values():
                    fout.write(encoded_output)
                    fout.write("\n")
            print(f"{key.capitalize()} set written to {out_path}")
def main():
    flutes.register_ipython_excepthook()
    random.seed(ghcc.__MAGIC__)
    np.random.seed(ghcc.__MAGIC__)
    repo_info = analyze_logs(args.log_file)
    changed = changed_repos(repo_info)

    # Sample 100 failed repositories.
    repos_with_fail = [repo for repo, info in repo_info.items()
                       if info["n_partial"][-1] < info["n_total"][-1]]
    samples = np.random.choice(len(repos_with_fail), 100, replace=False)
    _repo_samples = [repos_with_fail[x] for x in samples]

    # Remove repositories with more than 50 Makefiles.
    repo_samples = []
    for repo in _repo_samples:
        _, val = repo_info[repo]["n_total"][-1]
        if val <= 50:
            repo_samples.append(repo)
        else:
            print(f"{repo} contains {val} Makefiles, skipping")

    # Clone the repositories.
    for repo in tqdm(repo_samples, desc="Cloning repos"):
        owner, name = repo.split("/")
        ghcc.clone(owner, name, "test_compile")

    # Write repository information into a CSV file.
    # Each line is a separate Makefile.
    db = ghcc.RepoDB()
    with open("repo_samples.csv", "w") as f:
        writer = csv.writer(f)
        writer.writerow(["Repo", "Makefile", "Status", "Failed Reason?"])
        for repo in tqdm(repo_samples, desc="Writing CSV"):
            makefiles = ghcc.find_makefiles(os.path.join("test_compile", repo))
            owner, name = repo.split("/")
            entry = db.get(owner, name)
            success_makefiles = set()
            for makefile_info in entry['makefiles']:
                directory = makefile_info["directory"]
                directory = "/".join([owner, name] + directory.split("/")[4:])
                success_makefiles.add(directory)
            for makefile in makefiles:
                directory = "/".join(makefile.split("/")[1:])
                status = "" if directory in success_makefiles else "Failed"
                writer.writerow([repo, directory, status])
                print(repo, directory, status)
def main():
    args = Args()
    print(args.to_string())
    if args.pdb:
        flutes.register_ipython_excepthook()
    random.seed(args.random_seed)
    files = [file for file in Path(args.data_dir).iterdir() if file.suffix == ".pkl"]
    manager = flutes.ProgressBarManager()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with flutes.safe_pool(args.n_procs, state_class=PoolState,
                          init_args=(manager.proxy,)) as pool:

        def _wrap_iter() -> Iterator[str]:
            progress = manager.proxy.new(total=len(files))
            for tokens in pool.imap_unordered(PoolState.collect_tokens, files):
                yield from tokens
                progress.update(1)
            ident_tokens = Counter()
            for id_counts in pool.get_states():
                ident_tokens.update(id_counts)
            progress = manager.proxy.new(total=sum(ident_tokens.values()))
            for token, count in ident_tokens.items():
                yield from [token] * count
                progress.update(count)

        sampled_tokens = sample(args.sample_size, _wrap_iter())

    with (output_dir / "tokens.txt").open("w") as f:
        random.shuffle(sampled_tokens)
        for token in sampled_tokens:
            f.write(token)
            f.write("\n")
    spm_train_args = {
        "input": output_dir / "tokens.txt",
        "model_prefix": output_dir / "vocab",
        "vocab_size": args.vocab_size,
        "split_by_whitespace": 0,  # false
        "remove_extra_whitespaces": 0,  # false
        # "input_sentence_size": 10 ** 8,
    }
    spm.SentencePieceTrainer.Train(
        " ".join(f"--{name}={str(value)}" for name, value in spm_train_args.items()))
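# The `sample(k, iterable)` helper used above is assumed to be defined elsewhere in this
# script. A minimal reservoir-sampling sketch with the same (hypothetical) signature --
# not necessarily the project's actual implementation -- could look like this:
import random
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def sample(size: int, iterable: Iterable[T]) -> List[T]:
    # Keep a uniform random sample of `size` items from a stream of unknown length
    # (Algorithm R: fill the reservoir, then replace elements with decreasing probability).
    reservoir: List[T] = []
    for idx, item in enumerate(iterable):
        if idx < size:
            reservoir.append(item)
        else:
            pos = random.randrange(idx + 1)
            if pos < size:
                reservoir[pos] = item
    return reservoir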
def main() -> None:
    if args.n_procs == 0:
        # Only do this in the single-threaded case.
        flutes.register_ipython_excepthook()
    flutes.log(f"Running with {args.n_procs} worker processes", "warning")

    # Check for/create output directories.
    make_directory(args.output_dir)

    # Use RAM-backed memory for tmp if available.
    if os.path.exists('/dev/shm'):
        tempfile.tempdir = '/dev/shm'

    flutes.set_log_file(args.log_file)
    write_pseudo_registry()

    # Obtain a list of all binaries.
    binaries = get_binary_mapping(args.binary_mapping_cache_file)
    flutes.log(f"{len(binaries)} binaries to process.")
    file_count = 0
    db = ghcc.BinaryDB()

    with flutes.safe_pool(args.n_procs, closing=[db]) as pool:
        decompile_fn: Callable[[BinaryInfo], DecompilationResult] = functools.partial(
            decompile, output_dir=args.output_dir,
            binary_dir=args.binaries_dir, timeout=args.timeout)
        for result in pool.imap_unordered(decompile_fn, iter_binaries(db, binaries)):
            file_count += 1
            if result is not None:
                db.add_binary(result.info["repo_owner"], result.info["repo_name"],
                              result.hash, result.status is DecompilationStatus.Success)
            if file_count % 100 == 0:
                flutes.log(f"Processed {file_count} binaries", force_console=True)
def main() -> None:
    if not ghcc.utils.verify_docker_image(verbose=True):
        exit(1)
    sys.setrecursionlimit(10000)

    args = Arguments()
    if args.pdb:
        flutes.register_ipython_excepthook()
    if args.n_procs == 0:
        globals()['match_functions'] = match_functions.__wrapped__
    if not args.verbose:
        flutes.set_logging_level("quiet", console=True, file=False)
    flutes.set_log_file(args.log_file)
    flutes.log("Running with arguments:\n" + args.to_string(), force_console=True)

    if os.path.exists(args.temp_dir):
        flutes.log(f"Removing contents of temporary folder '{args.temp_dir}'...",
                   "warning", force_console=True)
        ghcc.utils.run_docker_command(["rm", "-rf", "/usr/src/*"], user=0,
                                      directory_mapping={args.temp_dir: "/usr/src"})

    db = ghcc.MatchFuncDB()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    manager = flutes.ProgressBarManager(
        verbose=args.show_progress,
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}{postfix}]")
    with flutes.safe_pool(args.n_procs, closing=[db, manager]) as pool:
        iterator, stats = iter_repos(
            db, args.max_repos, skip_to=args.skip_to,
            cache_path=args.repo_binary_info_cache_path,
            force_reprocess=args.force_reprocess)
        match_fn: Callable[[RepoInfo], Result] = functools.partial(
            match_functions,
            archive_folder=args.archive_dir, temp_folder=args.temp_dir,
            decompile_folder=args.decompile_dir,
            use_fake_libc_headers=args.use_fake_libc_headers,
            preprocess_timeout=args.preprocess_timeout,
            progress_bar=manager.proxy)

        repo_count = stats.repo_count
        func_count = stats.func_count
        func_without_ast_count = stats.func_without_ast_count
        for result in pool.imap_unordered(match_fn, iterator):
            if result is None:
                # Exception occurred.
                if args.exit_on_exception:
                    flutes.log("Exception occurred, exiting because 'exit_on_exception' is True",
                               "warning")
                    break
                continue

            # Write the matched functions to disk.
            result: Result  # type: ignore
            repo_dir = output_dir / result.repo_owner / result.repo_name
            repo_dir.mkdir(parents=True, exist_ok=True)
            with (repo_dir / "matched_funcs.jsonl").open("w") as f:
                for matched_func in result.matched_functions:
                    f.write(json.dumps(matched_func._asdict(), separators=(',', ':')) + "\n")
            for sha, code in result.preprocessed_original_code.items():
                with (repo_dir / f"{sha}.c").open("w") as f:
                    pos = code.rfind(ghcc.parse.FAKE_LIBC_END_LINE)
                    if pos != -1:
                        code = code[(pos + len(ghcc.parse.FAKE_LIBC_END_LINE)):]
                    f.write(code)
            if args.write_db:
                db.add_repo(result.repo_owner, result.repo_name,
                            files_found=result.files_found,
                            funcs_found=result.functions_found,
                            funcs_matched=len(result.matched_functions),
                            funcs_matched_without_ast=result.funcs_without_asts)

            repo_count += 1
            func_count += len(result.matched_functions)
            func_without_ast_count += result.funcs_without_asts
            if repo_count % 100 == 0:
                flutes.log(f"Processed {repo_count} repositories, {func_count} functions matched "
                           f"({func_without_ast_count} w/o AST)", force_console=True)
def main():
    args = Args()
    flutes.register_ipython_excepthook()
    src_data, tgt_data, additional_data = read_pairs(args.data_file, return_additional_data=True)
    # names = args.hyp_names.split(",")
    # hyp_paths = args.hyp_files.split(",")
    # overlap_paths = args.overlap_score_files.split(",")
    systems = [
        # (name, is_finetune?)
        ("Seq2seq-D", False),
        ("Seq2seq-O", False),
        # ("Seq2seq-D+Finetune", True),
        # ("Seq2seq-O+Finetune", True),
        ("TranX-BPE-D-Greedy", False),
        ("TranX-BPE-D-Beam5", False),
        ("TranX-BPE-O-Greedy", False),
        ("TranX-BPE-O-Beam5", False),
        # ("TranX-t2t-D-Greedy+Finetune", True),
        # ("TranX-t2t-D-Beam5+Finetune", True),
        # ("TranX-t2t-O-Greedy+Finetune", True),
        # ("TranX-t2t-O-Beam5+Finetune", True),
        ("Seq2seq-D-Small", False),
        ("Tree2tree", False),
        ("Tree2tree-BPE", False),
        ("TranX-Small-D-Greedy", False),
    ]
    names = [name for name, _ in systems]
    hyp_paths = [
        "outputs_canon_new_decomp/test_default.hyp.orig",
        "outputs_canon_new_orig/test_default.hyp.orig",
        # "outputs_decomp_varname_finetune/test_default.hyp.orig",
        # "outputs_orig_varname_finetune/test_default.hyp.orig",
        get_tranx_path("decompiled", beam_size=1),
        get_tranx_path("original", beam_size=1),
        get_tranx_path("decompiled", beam_size=5),
        get_tranx_path("original", beam_size=5),
        # get_tranx_path("decompiled", beam_size=1, finetune=True),
        # get_tranx_path("original", beam_size=1, finetune=True),
        # get_tranx_path("decompiled", beam_size=5, finetune=True),
        # get_tranx_path("original", beam_size=5, finetune=True),
        "outputs_canon_new_decomp_small/test_default.hyp.orig",
        "../Tree2Tree-master/test_lr1e-3_139570.pkl.txt",
        "../Tree2Tree-master/test_bpe_64000.pkl.txt",
        get_tranx_path("decompiled", beam_size=1, train_file="tranx_data_small"),
    ]
    # overlap_paths = ["data_canonical/" + ("overlap_test.txt" if not is_finetune else "overlap_extra_test.txt")
    #                  for _, is_finetune in systems]
    scores = [float(x) for x in read_lines("test_overlap.txt", verbose=False)]
    overlap_paths = ["test_overlap.txt" for _ in systems]
    assert len(names) == len(hyp_paths)
    print("\n".join(f"{name}: {path}" for name, path in zip(names, hyp_paths)))

    hyp_data = {}
    for name, hyp_path in zip(names, hyp_paths):
        hyp_data[name] = list(read_lines(hyp_path, verbose=False, skip_empty=False))
    if len(overlap_paths) == 1:
        overlap_paths = overlap_paths * len(names)
    assert len(overlap_paths) == len(names)
    overlap_scores = {name: scores for name in names}
    assert len(src_data) == len(tgt_data) == len(additional_data)
    assert len(overlap_scores) == len(names) == len(hyp_data)
    assert all(len(data) == len(src_data) == len(scores)
               for data, scores in zip(hyp_data.values(), overlap_scores.values()))

    pickle_file = Path(args.pickle_file)
    with pickle_file.open("wb") as f:
        obj = (names, src_data, tgt_data, hyp_data, overlap_scores, additional_data)
        pickle.dump(obj, f)
    print(f"File written to {pickle_file}")

    # # Separate pickles for examples with overlap scores <= 0.5 or > 0.5
    # similar_indices = [idx for idx, score in enumerate(scores) if score > 0.5]
    # dissimilar_indices = [idx for idx, score in enumerate(scores) if score <= 0.5]
    # for indices, suffix in [(similar_indices, "similar"), (dissimilar_indices, "dissimilar")]:
    #     file = pickle_file.with_name(pickle_file.name + "_" + suffix)
    #     with file.open("wb") as f:
    #         obj = (
    #             names,
    #             [src_data[i] for i in indices],
    #             [tgt_data[i] for i in indices],
    #             {name: [data[i] for i in indices] for name, data in hyp_data.items()},
    #             {name: [data[i] for i in indices] for name, data in overlap_scores.items()},
    #             [additional_data[i] for i in indices],
    #         )
    #         pickle.dump(obj, f)
    #     print(f"File written to {file}")
    return

    # Everything below is unreachable due to the early return above; kept from an earlier
    # version of the script for reference.
    bleu_scores = []
    edit_scores = []
    for src, ref, hyp in zip(tqdm(src_data, desc="Computing scores"), ref_data, hyp_data):
        bleu4 = tx.evals.sentence_bleu([ref], hyp, max_order=4, smooth=True)
        bleu8 = tx.evals.sentence_bleu([ref], hyp, max_order=8, smooth=True)
        bleu_scores.append((bleu4, bleu8))
        edit_neu, edit_pos, edit_neg = batch_compute_edit_score(
            src, ref, hyp,
            reward_and_penalty=[(0.0, 0.0), (0.5, 0.0), (0.0, -0.5)])
        edit_scores.append((edit_neu, edit_pos, edit_neg))

    indices = sorted(range(len(src_data)), key=lambda i: bleu_scores[i][0])
    with open(args.output_file, "w") as f:
        for idx in tqdm(indices, desc="Writing output"):
            src = src_data[idx]
            ref = ref_data[idx]
            hyp = hyp_data[idx]
            overlap_score = overlap_scores[idx]
            ref, hyp = color_match(ref, hyp)
            bleu4, bleu8 = bleu_scores[idx]
            f.write(colored(f"Example {idx} (BLEU4 = {bleu4:.2f}, BLEU8 = {bleu8:.2f}), "
                            f"overlap with train = ", "yellow") +
                    colored(f"{overlap_score:.3f}",
                            "red" if overlap_score > 0.8 else "yellow") + "\n")
            f.write(colored("Source: ", "blue") + src + "\n")
            f.write(colored("Target: ", "blue") + ref + "\n")
            f.write(colored("Prediction: ", "blue") + hyp + "\n")
            f.write("\n")

    bleu4 = tx.evals.corpus_bleu([[tgt] for tgt in tgt_data], hyp_data, max_order=4)
    bleu8 = tx.evals.corpus_bleu([[tgt] for tgt in tgt_data], hyp_data, max_order=8)
    print(f"BLEU4 = {bleu4:.2f}, BLEU8 = {bleu8:.2f}")
def main() -> None:
    if not ghcc.utils.verify_docker_image(verbose=True):
        exit(1)

    args = Arguments()
    if args.n_procs == 0:
        # Only do this in the single-threaded case.
        flutes.register_ipython_excepthook()
    flutes.set_log_file(args.log_file)
    flutes.set_logging_level(args.logging_level, console=True, file=False)
    flutes.log("Running with arguments:\n" + args.to_string(), force_console=True)

    if os.path.exists(args.clone_folder):
        flutes.log(f"Removing contents of clone folder '{args.clone_folder}'...",
                   "warning", force_console=True)
        ghcc.utils.run_docker_command(["rm", "-rf", "/usr/src/*"], user=0,
                                      directory_mapping={args.clone_folder: "/usr/src"})

    flutes.log("Crawling starts...", "warning", force_console=True)
    db = ghcc.RepoDB()
    libraries: Set[str] = set()
    if args.record_libraries is not None and os.path.exists(args.record_libraries):
        with open(args.record_libraries, "r") as f:
            libraries = set(f.read().split())

    def flush_libraries():
        if args.record_libraries is not None:
            with open(args.record_libraries, "w") as f:
                f.write("\n".join(libraries))

    with flutes.safe_pool(args.n_procs, closing=[db, flush_libraries]) as pool:
        iterator = iter_repos(db, args.repo_list_file, args.max_repos)
        pipeline_fn: Callable[[RepoInfo], Optional[PipelineResult]] = functools.partial(
            clone_and_compile,
            clone_folder=args.clone_folder, binary_folder=args.binary_folder,
            archive_folder=args.archive_folder,
            recursive_clone=args.recursive_clone,
            clone_timeout=args.clone_timeout, compile_timeout=args.compile_timeout,
            force_reclone=args.force_reclone, force_recompile=args.force_recompile,
            docker_batch_compile=args.docker_batch_compile,
            max_archive_size=args.max_archive_size,
            compression_type=args.compression_type,
            record_libraries=(args.record_libraries is not None),
            record_metainfo=args.record_metainfo,
            gcc_override_flags=args.gcc_override_flags)
        repo_count = 0
        meta_info = MetaInfo()
        for result in pool.imap_unordered(pipeline_fn, iterator):
            repo_count += 1
            if repo_count % 100 == 0:
                flutes.log(f"Processed {repo_count} repositories", force_console=True)
            if result is None:
                continue
            repo_owner, repo_name = result.repo_info.repo_owner, result.repo_info.repo_name
            if args.write_db:
                if result.clone_success is not None or result.repo_info.db_result is None:
                    # There's probably an inconsistency somewhere if we didn't clone while `db_result` is None.
                    # To prevent more errors, just add it to the DB.
                    repo_size = result.repo_size or -1  # a value of zero is probably also wrong
                    clone_success = result.clone_success if result.clone_success is not None else True
                    db.add_repo(repo_owner, repo_name, clone_success, repo_size=repo_size)
                    flutes.log(f"Added {repo_owner}/{repo_name} to DB")
                if result.makefiles is not None:
                    update_result = db.update_makefile(repo_owner, repo_name, result.makefiles,
                                                       ignore_length_mismatch=True)
                    if not update_result:
                        flutes.log(f"Makefiles of {repo_owner}/{repo_name} not saved to DB due to "
                                   f"Unicode encoding errors", "error")
            if result.libraries is not None:
                libraries.update(result.libraries)
                if repo_count % 10 == 0:  # flush every 10 repos
                    flush_libraries()
            if args.record_metainfo:
                meta_info.add_repo(result)
                if repo_count % 100 == 0:
                    flutes.log(repr(meta_info), force_console=True)

    flutes.log(repr(meta_info), force_console=True)
def test_register_ipython_excepthook() -> None:
    flutes.register_ipython_excepthook()
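# `flutes.register_ipython_excepthook()` installs an exception hook that drops into an
# interactive debugger when an uncaught exception propagates out of the program. A minimal
# sketch of the same idea using only the standard library (pdb instead of IPython; this is
# not the actual flutes implementation):
import pdb
import sys
import traceback


def register_pdb_excepthook() -> None:
    def excepthook(exc_type, exc_value, exc_traceback):
        # Print the traceback as usual, then start a post-mortem debugging session.
        traceback.print_exception(exc_type, exc_value, exc_traceback)
        pdb.post_mortem(exc_traceback)

    sys.excepthook = excepthook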
def main() -> None:
    args = Args()
    if args.pdb:
        flutes.register_ipython_excepthook()
    if args.debug:
        print(colored("Running in debug mode: no checkpoints or logs will be saved", "yellow"))

    with open(args.config_file) as f:
        config: Dict[str, Any] = cotra.utils.load_yaml(f)
    if args.extra_config is not None:
        import ast
        extra_config = ast.literal_eval(args.extra_config)
        cotra.utils.merge_dict(config, extra_config)
    # Do some validation before running time-consuming processes.
    assert os.path.exists(config["data"]["training_set"])
    assert all(os.path.exists(path) for path in config["data"]["valid_sets"].values())
    assert all(os.path.exists(path) for path in config["data"]["test_sets"].values())

    tx.run.make_deterministic(config["random_seed"])
    print(f"Random seed set to {config['random_seed']}")

    output_dir = Path(args.output_dir)
    if not args.debug and output_dir.exists() and args.run_mode == "train" and not args.force:
        print(colored(f"Output folder '{str(output_dir)}' exists, use --force to overwrite."))
        sys.exit(1)

    # Load data
    eval_datasets: Dict[str, Dict[str, cotra.CodeData]] = {}
    hparams = copy.deepcopy(config["data"]["hparams"])
    vocab = cotra.utils.Vocab.load(config["data"]["vocab_file"])
    train_dataset = None
    if args.run_mode == "train":
        train_dataset = cotra.CodeData(path=config["data"]["training_set"], vocab=vocab, hparams={
            **hparams,
            "shuffle": True,
            "curriculum": {"enabled": args.curriculum},
            "verbose": config["data"]["verbose"],
            "num_parallel_calls": args.n_procs,
            # "lazy_strategy": "all",
            # "cache_strategy": "none",
            # "shuffle_buffer_size": 4096,
        })
    eval_splits: Dict[str, Dict[str, str]] = {
        "valid": config["data"]["valid_sets"],
        "test": config["data"]["test_sets"],
    }
    for split, paths in eval_splits.items():
        eval_datasets[split] = {
            f"{split}_{name}": cotra.CodeData(path=path, vocab=vocab, hparams={
                **hparams,
                "shuffle": False,
                "curriculum": {"enabled": False},
                "batch_size": config["training"]["test_batch_size"],
                "max_dataset_size": 500 if split == "valid" else -1,
                # Evaluation must use truncate mode -- no example in the test set should be discarded.
                "length_filter_mode": "truncate",
            })
            for name, path in paths.items()
        }
    batching_strategy = cotra.CustomBatchingStrategy(config["training"]["max_batch_tokens"])
    print("Dataset initialized")

    # Create model and optimizer
    model = cotra.Seq2seq(vocab, hparams=config["model"])
    model = ModelWrapper(model, beam_width=config["inference"]["beam_width"],
                         length_penalty=config["inference"]["length_penalty"])

    lr_config = config["lr_scheduler"]
    optim = torch.optim.Adam(model.parameters(), lr=lr_config["lr"], betas=(0.9, 0.997), eps=1e-9)
    scheduler = cotra.utils.get_lr_scheduler(optim, lr_config)
    print("Model constructed")

    training_config = config["training"]
    test_output_path = output_dir / "test.output"
    valid_set = next(iter(eval_datasets["valid"].values()))  # only validate on first valid split
    actions_on_plateau = []
    if lr_config.get("lr_decay", 0.0) > 0.0:
        actions_on_plateau.append(tx.run.action.scale_lr(lr_config["lr_decay"]))
    executor = Executor(
        model=model,
        train_data=train_dataset,
        valid_data=valid_set,
        test_data=eval_datasets["test"],
        batching_strategy=batching_strategy,
        optimizer=optim,
        lr_scheduler=scheduler,
        grad_clip=training_config.get("grad_clip", None),
        log_destination=[sys.stdout,
                         *([output_dir / "log.txt"] if not args.debug else [])],
        log_every=cond.iteration(training_config["display_steps"]),
        validate_every=cond.iteration(training_config["eval_steps"]),
        stop_training_on=cond.iteration(training_config["max_train_steps"]),
        train_metrics=[("loss", metric.RunningAverage(20)),  # average over 20 iterations
                       ("lr", metric.LR(optim))],
        log_format="{time} : Epoch {epoch:2d} @ {iteration:6d}it "
                   "({progress}%, {speed}), lr = {lr:.3e}, loss = {loss:.3f}",
        valid_metrics=cotra.utils.WordPieceBLEU(vocab, decode=True, encoding="spm",
                                                sample_output_per=len(valid_set) // 10),
        plateau_condition=cond.validation(better=False),
        action_on_plateau=actions_on_plateau,
        test_metrics=[cotra.utils.FileBLEU(vocab, test_output_path, encoding="spm"),
                      ("unofficial_bleu",
                       cotra.utils.WordPieceBLEU(vocab, decode=True, encoding="spm"))],
        valid_log_format="{time} : Epoch {epoch}, {split} BLEU = {BLEU:.3f}",
        test_progress_log_format=("{time} : Evaluating on {split} ({progress}%, {speed}), "
                                  "unofficial BLEU = {unofficial_bleu:.2f}"),
        validate_mode='predict',
        checkpoint_dir=(args.output_dir if not args.debug else None),
        save_every=(cond.validation(better=True) if not args.debug else None),
        max_to_keep=5,
        show_live_progress=True,
    )

    executor.write_log(pprint.pformat(config))
    all_datasets = {"train": train_dataset,
                    **{key: value
                       for datasets in eval_datasets.values()
                       for key, value in datasets.items()}}
    try:
        executor.write_log("Data size: " + repr(
            {key: len(split) for key, split in all_datasets.items() if split is not None}))
    except TypeError:
        pass
    n_params = sum(param.numel() for param in model.parameters())
    executor.write_log(f"#Parameters: {str(n_params)}")

    if args.curriculum:
        @executor.on_event(cond.Event.Epoch, 'begin')
        def update_dataset_steps(exc: Executor):
            assert train_dataset is not None
            train_dataset.update_steps(exc.status["iteration"])
            exc._train_tracker.set_size(len(train_dataset))
            exc.write_log(f"Epoch {exc.status['epoch']}, competency updated to "
                          f"{train_dataset.competency * 100:6.2f}%")

    executor.write_log(f"Begin running with {args.run_mode} mode")
    if args.run_mode == "train":
        if args.load_checkpoint:
            load_path = executor.load(path=args.checkpoint_path, allow_failure=True)
            # if load_path is not None:
            #     executor.test(eval_datasets["valid"])
        executor.train()
    else:
        executor.load(path=args.checkpoint_path, load_training_state=False)
        for name, dataset in eval_datasets[args.run_mode].items():
            executor.test({name: dataset})
            # Manually rename the test output file.
            os.rename(str(test_output_path) + ".hyp",
                      output_dir / args.test_output_file.format(split=name))
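# The script above reads a number of keys from the YAML config it loads. A hedged sketch of
# the expected structure, written as the equivalent Python dict -- all values below are
# illustrative placeholders, not taken from the project's actual configs:
EXAMPLE_CONFIG = {
    "random_seed": 1234,
    "data": {
        "training_set": "data/train.txt",
        "valid_sets": {"valid": "data/valid.txt"},
        "test_sets": {"test": "data/test.txt"},
        "vocab_file": "data/vocab.txt",
        "verbose": False,
        "hparams": {},  # passed through to `cotra.CodeData`
    },
    "training": {
        "test_batch_size": 64,
        "max_batch_tokens": 8192,
        "display_steps": 100,
        "eval_steps": 2000,
        "max_train_steps": 100000,
        "grad_clip": 5.0,
    },
    "model": {},  # hyperparameters for `cotra.Seq2seq`
    "inference": {"beam_width": 5, "length_penalty": 1.0},
    "lr_scheduler": {"lr": 1e-3, "lr_decay": 0.5},
}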
def main():
    args = Args()
    print(args.to_string())
    random.seed(args.random_seed)
    if args.pdb:
        flutes.register_ipython_excepthook()
    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    data = flutes.LazyList(read_data(args.data_dir))

    @flutes.cache(output_dir / "scores.pkl", name="competency scores")
    def compute_scores():
        # Gather word counts and compute "competency" scores for each example, for use in curriculum learning.
        with flutes.safe_pool(args.n_procs, state_class=CountWordsState) as pool:
            for _ in pool.imap_unordered(
                    CountWordsState.count_words,  # use only the source sentence
                    map(lambda ex: ex.decompiled_code, tqdm(data, desc="Counting words")),
                    chunksize=args.block_size):
                pass
            word_counter = Counter()
            for state in pool.get_states():
                word_counter.update(state.counter)
        total_words = sum(word_counter.values())
        word_scores = {w: -math.log(c / total_words) for w, c in word_counter.items()}
        with flutes.safe_pool(args.n_procs, state_class=ComputeScoreState,
                              init_args=(word_scores,)) as pool:
            scores = list(pool.imap(
                ComputeScoreState.compute_score,
                map(lambda ex: ex.decompiled_code, tqdm(data, desc="Computing scores")),
                chunksize=args.block_size))
        return scores

    scores = compute_scores()

    # Generate data splits.
    test_size = args.test_split_size or int(len(data) * args.test_split_portion)
    # Dev/Test set: Repositories excluded in training set.
    data_by_repo = defaultdict(list)
    for idx, ex in enumerate(data):
        data_by_repo[ex.repo].append(idx)
    repo_names = list(data_by_repo.keys())

    def create_excluded_split(target_size: int, max_repos: int, extra_train_portion: float,
                              min_repo_size: int = 0) -> Tuple[List[str], List[int], List[int]]:
        # ([name], [index])
        filtered_repos = repo_names
        if min_repo_size > 0:
            filtered_repos = [repo for repo in filtered_repos
                              if len(data_by_repo[repo]) >= min_repo_size]
        while True:
            repo_count = random.randint(1, min(len(filtered_repos), max_repos))
            chosen_repos = random.choices(filtered_repos, k=repo_count)
            sample_size = sum(len(data_by_repo[name]) for name in chosen_repos)
            if 0.8 * target_size <= sample_size <= 1.1 * target_size:
                # Keep sampling until we get something with appropriate size.
                break
        extra_train_indices = []
        split_indices = []
        for name in chosen_repos:
            indices = data_by_repo[name].copy()
            random.shuffle(indices)
            split_size = int(len(indices) * extra_train_portion)
            extra_train_indices += indices[:split_size]
            split_indices += indices[split_size:]
        return chosen_repos, split_indices, extra_train_indices

    dev_repos, dev_split, extra_train_dev_split = create_excluded_split(
        test_size, args.max_test_repos, args.extra_train_portion)
    for repo_name in dev_repos:
        del data_by_repo[repo_name]
    test_repos, test_split, extra_train_test_split = create_excluded_split(
        test_size, args.max_test_repos, args.extra_train_portion)
    excluded_indices = set(dev_split + extra_train_dev_split + test_split + extra_train_test_split)

    # Training set: all the remaining stuff.
    train_split = [idx for idx in range(len(data)) if idx not in excluded_indices]
    train_split.sort(key=lambda i: scores[i])  # sort training indices according to competency score
    extra_train_split = extra_train_dev_split + extra_train_test_split

    splits = {
        "train": train_split,
        "valid": dev_split,
        "test": test_split,
        "train_extra": extra_train_split,
    }
    with (output_dir / "split_indices.pkl").open("wb") as f:
        pickle.dump(splits, f)

    def write_files(folder_path: str, sentence_fn: Callable[[int], Tuple[str, str]]):
        for key, indices in splits.items():
            with open(os.path.join(folder_path, f"{key}.txt"), "w") as f:
                for idx in tqdm(indices, desc=f"Writing {key} set", leave=False):
                    src, tgt = sentence_fn(idx)
                    ex = src, tgt, data[idx].var_names, scores[idx], data[idx].repo, data[idx].sha
                    output = OutputData.encode(*ex)
                    # assert tuple(OutputData.decode(output)) == ex
                    f.write(output)
                    f.write("\n")
            print(f"{key.capitalize()} set written")

    write_files(output_dir, lambda idx: data[idx][:2])
    for key, names in [("valid", dev_repos), ("test", test_repos)]:
        with (output_dir / f"{key}_repos.txt").open("w") as f:
            f.write("\n".join(names))

    if not (output_dir / "vocab.model").exists():
        # Write out training text and train SentencePiece model.
        train_text_path = output_dir / "train_text.txt"
        with train_text_path.open("w") as f:
            for idx in tqdm(train_split, desc="Writing training text"):
                src_tokens = data[idx].decompiled_code.split(TOKEN_SEP)
                new_src_tokens = []
                for token in src_tokens:
                    if token in data[idx].var_names:
                        var1, var2 = data[idx].var_names[token]
                        new_src_tokens += [var1, var2]
                    else:
                        new_src_tokens.append(token)
                f.write(" ".join(new_src_tokens) + "\n")
                f.write(data[idx].original_code.replace(TOKEN_SEP, " ") + "\n")
        spm_train_args = {
            "input": train_text_path,
            "model_prefix": output_dir / "vocab",
            "vocab_size": args.vocab_size,
        }
        spm.SentencePieceTrainer.Train(
            " ".join(f"--{name}={str(value)}" for name, value in spm_train_args.items()))

    if args.encode_spm:
        # Encode all sentences with the trained SP model.
        with flutes.safe_pool(args.n_procs, state_class=EncodeSPMState,
                              init_args=(output_dir / "vocab.model",)) as pool:
            processed_code = list(pool.imap(
                EncodeSPMState.encode_spm,
                map(lambda ex: (ex.decompiled_code, ex.original_code),
                    tqdm(data, desc="Encoding with SPM")),
                chunksize=args.block_size))
        write_files(output_dir / "tokenized", lambda idx: processed_code[idx])
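# The script above writes the trained SentencePiece model to `<output_dir>/vocab.model`.
# A small sketch of how such a model could be loaded and applied with the standard
# `sentencepiece` API (the model path and input string below are assumptions, not taken
# from the script's arguments):
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("data/vocab.model")                        # hypothetical path to the trained model
pieces = sp.EncodeAsPieces("int main ( ) { return 0 ; }")
print(pieces)                                      # list of subword pieces
print(sp.DecodePieces(pieces))                     # round-trips back to the original text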