Example #1
0
def main():
    args = Args()
    flutes.register_ipython_excepthook()
    with open(args.binary_list) as f:
        binaries = [line.split() for line in f if line.strip()]
    output_dir = Path(args.output_dir)

    file_descriptions = [
        FileDescription(name="binaries",
                        folder="binaries",
                        pattern=args.binary_path_pattern,
                        filename="{sha}"),
        FileDescription(name="decompiled output",
                        folder="decompiled",
                        pattern=args.decompiled_path_pattern,
                        filename="{sha}.jsonl"),
        # FileDescription(name="matched functions",
        #                 folder="matched_funcs", pattern=args.matched_func_path_pattern, filename="{sha}.jsonl"),
        FileDescription(name="preprocessed code",
                        folder="code",
                        pattern=args.preprocessed_code_path_pattern,
                        filename="{sha}.c"),
    ]

    for desc in file_descriptions:
        output_folder = output_dir / desc.folder
        output_folder.mkdir(exist_ok=True, parents=True)
        for repo, sha in tqdm(binaries, desc=f"Copying {desc.name}"):
            shutil.copy(desc.pattern.format(repo=repo, sha=sha),
                        output_folder / desc.filename.format(sha=sha))
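
A minimal self-contained sketch of the pattern in Example #1: each file group is declared as a NamedTuple carrying a source path template and a destination filename template, and a manifest of (repo, sha) pairs drives the copy. The field meanings follow the example; the concrete path templates are assumptions.

import shutil
from pathlib import Path
from typing import Iterable, NamedTuple, Tuple

class FileDescription(NamedTuple):
    name: str      # human-readable group name, used in progress messages
    folder: str    # destination subfolder under the output directory
    pattern: str   # source path template, e.g. "repos/{repo}/{sha}.bin"
    filename: str  # destination filename template, e.g. "{sha}.bin"

def copy_group(desc: FileDescription,
               binaries: Iterable[Tuple[str, str]],
               output_dir: Path) -> None:
    out_folder = output_dir / desc.folder
    out_folder.mkdir(parents=True, exist_ok=True)
    for repo, sha in binaries:
        src = desc.pattern.format(repo=repo, sha=sha)
        shutil.copy(src, out_folder / desc.filename.format(sha=sha))
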
Example #2
0
def main():
    args = Args()
    print(args.to_string())
    random.seed(args.random_seed)
    if args.pdb:
        flutes.register_ipython_excepthook()

    splits = {
        "train_extra": "train_extra/",
        "valid": "dev/",
        "test": "test/",
        "train": ".",
    }

    tranx_data_dir = Path(args.tranx_data_dir)
    text_data_dir = Path(args.text_data_dir)
    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    with flutes.safe_pool(args.n_procs, state_class=State) as pool:
        for key, path in splits.items():
            pool.broadcast(State.clear)
            files = [
                file for file in sorted((tranx_data_dir / path).iterdir())
                if file.name.startswith("data") and file.suffix == ".pkl"
            ]
            for _ in tqdm(pool.imap_unordered(State.process_data, files),
                          total=len(files)):
                pass
            canonical_tgt_map: Dict[str, str] = {}
            total_size = 0
            for state_map in pool.get_states():
                total_size += len(state_map.canonical_map)
                canonical_tgt_map.update(state_map.canonical_map)
            assert total_size == len(canonical_tgt_map)
            print(f"{key.capitalize()} set processed")
            in_path = text_data_dir / f"{key}.txt"
            out_path = output_dir / f"{key}.txt"
            not_found = set()
            with in_path.open("r") as fin, out_path.open("w") as fout:
                progress = tqdm(fin, total=total_size)
                for line in progress:
                    if not line.strip(): continue
                    example = InputData.decode(line)
                    encoded_output = canonical_tgt_map.get(
                        example.original_code, None)
                    if encoded_output is None:
                        not_found.add(example.original_code)
                        progress.set_postfix(not_found=len(not_found))
                    else:
                        del canonical_tgt_map[example.original_code]
                        fout.write(encoded_output)
                        fout.write("\n")
                print(
                    f"{len(not_found)} not found, {len(canonical_tgt_map)} leftover"
                )
                for encoded_output in canonical_tgt_map.values():
                    fout.write(encoded_output)
                    fout.write("\n")
                print(f"{key.capitalize()} set written to {out_path}")
Example #3
0
def main():
    flutes.register_ipython_excepthook()
    random.seed(ghcc.__MAGIC__)
    np.random.seed(ghcc.__MAGIC__)

    repo_info = analyze_logs(args.log_file)
    changed = changed_repos(repo_info)

    # Sample 100 failed repositories.
    repos_with_fail = [
        repo for repo, info in repo_info.items()
        if info["n_partial"][-1] < info["n_total"][-1]
    ]
    samples = np.random.choice(len(repos_with_fail), 100, replace=False)
    _repo_samples = [repos_with_fail[x] for x in samples]

    # Remove repositories with more than 50 Makefiles.
    repo_samples = []
    for repo in _repo_samples:
        _, val = repo_info[repo]["n_total"][-1]
        if val <= 50:
            repo_samples.append(repo)
        else:
            print(f"{repo} contains {val} Makefiles, skipping")

    # Clone the repositories.
    for repo in tqdm(repo_samples, desc="Cloning repos"):
        owner, name = repo.split("/")
        ghcc.clone(owner, name, "test_compile")

    # Write repository information into a CSV file.
    # Each line is a separate Makefile.
    db = ghcc.RepoDB()
    with open("repo_samples.csv", "w") as f:
        writer = csv.writer(f)
        writer.writerow(["Repo", "Makefile", "Status", "Failed Reason?"])

        for repo in tqdm(repo_samples, desc="Writing CSV"):
            makefiles = ghcc.find_makefiles(os.path.join("test_compile", repo))
            owner, name = repo.split("/")
            entry = db.get(owner, name)
            success_makefiles = set()
            for makefile_info in entry['makefiles']:
                directory = makefile_info["directory"]
                directory = "/".join([owner, name] + directory.split("/")[4:])
                success_makefiles.add(directory)
            for makefile in makefiles:
                directory = "/".join(makefile.split("/")[1:])
                status = "" if directory in success_makefiles else "Failed"
                # The header has four columns; leave "Failed Reason?" blank.
                writer.writerow([repo, directory, status, ""])
                print(repo, directory, status)
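
Example #3 samples indices without replacement and maps them back to items; a compact sketch of that step (stdlib random.sample would work equally well):

import numpy as np

def sample_items(items, k, seed=0):
    rng = np.random.RandomState(seed)
    indices = rng.choice(len(items), size=min(k, len(items)), replace=False)
    return [items[i] for i in indices]
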
Example #4
0
def main():
    args = Args()
    print(args.to_string())
    if args.pdb:
        flutes.register_ipython_excepthook()
    random.seed(args.random_seed)

    files = [
        file for file in Path(args.data_dir).iterdir() if file.suffix == ".pkl"
    ]
    manager = flutes.ProgressBarManager()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with flutes.safe_pool(args.n_procs,
                          state_class=PoolState,
                          init_args=(manager.proxy, )) as pool:

        def _wrap_iter() -> Iterator[str]:
            progress = manager.proxy.new(total=len(files))
            for tokens in pool.imap_unordered(PoolState.collect_tokens, files):
                yield from tokens
                progress.update(1)
            ident_tokens = Counter()
            for id_counts in pool.get_states():
                ident_tokens.update(id_counts)
            progress = manager.proxy.new(total=sum(ident_tokens.values()))
            for token, count in ident_tokens.items():
                yield from [token] * count
                progress.update(count)

        sampled_tokens = sample(args.sample_size, _wrap_iter())

    with (output_dir / "tokens.txt").open("w") as f:
        random.shuffle(sampled_tokens)
        for token in sampled_tokens:
            f.write(token)
            f.write("\n")

    spm_train_args = {
        "input": output_dir / "tokens.txt",
        "model_prefix": output_dir / "vocab",
        "vocab_size": args.vocab_size,
        "split_by_whitespace": 0,  # false
        "remove_extra_whitespaces": 0,  # false
        # "input_sentence_size": 10 ** 8,
    }
    spm.SentencePieceTrainer.Train(" ".join(
        f"--{name}={str(value)}" for name, value in spm_train_args.items()))
Example #5
0
def main() -> None:
    if args.n_procs == 0:
        # Only do this in the single-threaded case.
        flutes.register_ipython_excepthook()
    flutes.log(f"Running with {args.n_procs} worker processes", "warning")

    # Check for/create output directories
    make_directory(args.output_dir)

    # Use RAM-backed memory for tmp if available
    if os.path.exists('/dev/shm'):
        tempfile.tempdir = '/dev/shm'

    flutes.set_log_file(args.log_file)
    write_pseudo_registry()

    # Obtain a list of all binaries
    binaries = get_binary_mapping(args.binary_mapping_cache_file)

    flutes.log(f"{len(binaries)} binaries to process.")
    file_count = 0
    db = ghcc.BinaryDB()

    with flutes.safe_pool(args.n_procs, closing=[db]) as pool:
        decompile_fn: Callable[[BinaryInfo],
                               DecompilationResult] = functools.partial(
                                   decompile,
                                   output_dir=args.output_dir,
                                   binary_dir=args.binaries_dir,
                                   timeout=args.timeout)
        for result in pool.imap_unordered(decompile_fn,
                                          iter_binaries(db, binaries)):
            file_count += 1
            if result is not None:
                db.add_binary(result.info["repo_owner"],
                              result.info["repo_name"], result.hash,
                              result.status is DecompilationStatus.Success)
            if file_count % 100 == 0:
                flutes.log(f"Processed {file_count} binaries",
                           force_console=True)
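
Example #5 binds the fixed decompilation options with functools.partial so the pool only streams the varying argument. A stripped-down stdlib version of the same idiom (names here are illustrative):

import functools
from multiprocessing import Pool
from typing import Iterable

def process(item: str, *, output_dir: str, timeout: int) -> str:
    return f"{output_dir}/{item} (timeout={timeout}s)"

def run(items: Iterable[str]) -> None:
    fn = functools.partial(process, output_dir="out", timeout=30)
    with Pool() as pool:
        for count, _result in enumerate(pool.imap_unordered(fn, items), 1):
            if count % 100 == 0:  # periodic progress logging, as above
                print(f"Processed {count} items")
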
Example #6
0
def main() -> None:
    if not ghcc.utils.verify_docker_image(verbose=True):
        exit(1)

    sys.setrecursionlimit(10000)
    args = Arguments()
    if args.pdb:
        flutes.register_ipython_excepthook()
        if args.n_procs == 0:
            # Bypass the function's wrapping decorator so exceptions
            # propagate to the debugger instead of being caught.
            globals()['match_functions'] = match_functions.__wrapped__

    if not args.verbose:
        flutes.set_logging_level("quiet", console=True, file=False)
    flutes.set_log_file(args.log_file)
    flutes.log("Running with arguments:\n" + args.to_string(),
               force_console=True)

    if os.path.exists(args.temp_dir):
        flutes.log(
            f"Removing contents of temporary folder '{args.temp_dir}'...",
            "warning",
            force_console=True)
        ghcc.utils.run_docker_command(
            ["rm", "-rf", "/usr/src/*"],
            user=0,
            directory_mapping={args.temp_dir: "/usr/src"})

    db = ghcc.MatchFuncDB()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    manager = flutes.ProgressBarManager(
        verbose=args.show_progress,
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}{postfix}]")
    with flutes.safe_pool(args.n_procs, closing=[db, manager]) as pool:
        iterator, stats = iter_repos(
            db,
            args.max_repos,
            skip_to=args.skip_to,
            cache_path=args.repo_binary_info_cache_path,
            force_reprocess=args.force_reprocess)
        match_fn: Callable[[RepoInfo], Result] = functools.partial(
            match_functions,
            archive_folder=args.archive_dir,
            temp_folder=args.temp_dir,
            decompile_folder=args.decompile_dir,
            use_fake_libc_headers=args.use_fake_libc_headers,
            preprocess_timeout=args.preprocess_timeout,
            progress_bar=manager.proxy)

        repo_count = stats.repo_count
        func_count = stats.func_count
        func_without_ast_count = stats.func_without_ast_count
        for result in pool.imap_unordered(match_fn, iterator):
            if result is None:
                # Exception occurred.
                if args.exit_on_exception:
                    flutes.log(
                        "Exception occurred, exiting because "
                        "'exit_on_exception' is True", "warning")
                    break
                continue

            # Write the matched functions to disk.
            result: Result  # type: ignore
            repo_dir = output_dir / result.repo_owner / result.repo_name
            repo_dir.mkdir(parents=True, exist_ok=True)
            with (repo_dir / "matched_funcs.jsonl").open("w") as f:
                for matched_func in result.matched_functions:
                    f.write(
                        json.dumps(matched_func._asdict(),
                                   separators=(',', ':')) + "\n")
            for sha, code in result.preprocessed_original_code.items():
                with (repo_dir / f"{sha}.c").open("w") as f:
                    # Strip everything up to the end of the fake libc headers.
                    pos = code.rfind(ghcc.parse.FAKE_LIBC_END_LINE)
                    if pos != -1:
                        code = code[pos + len(ghcc.parse.FAKE_LIBC_END_LINE):]
                    f.write(code)

            if args.write_db:
                db.add_repo(
                    result.repo_owner,
                    result.repo_name,
                    files_found=result.files_found,
                    funcs_found=result.functions_found,
                    funcs_matched=len(result.matched_functions),
                    funcs_matched_without_ast=result.funcs_without_asts)

            repo_count += 1
            func_count += len(result.matched_functions)
            func_without_ast_count += result.funcs_without_asts
            if repo_count % 100 == 0:
                flutes.log(
                    f"Processed {repo_count} repositories, {func_count} functions matched "
                    f"({func_without_ast_count} w/o AST)",
                    force_console=True)
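
Example #6 writes matched functions as JSON Lines with compact separators. A minimal sketch of that output step; the record fields are illustrative.

import json
from pathlib import Path
from typing import Iterable, Mapping

def write_jsonl(records: Iterable[Mapping], path: Path) -> None:
    # One JSON object per line; compact separators keep files small.
    with path.open("w") as f:
        for record in records:
            f.write(json.dumps(record, separators=(",", ":")) + "\n")

# write_jsonl([{"name": "main", "score": 1.0}], Path("matched_funcs.jsonl"))
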
Example #7
0
def main():
    args = Args()
    flutes.register_ipython_excepthook()
    src_data, tgt_data, additional_data = read_pairs(
        args.data_file, return_additional_data=True)
    # names = args.hyp_names.split(",")
    # hyp_paths = args.hyp_files.split(",")
    # overlap_paths = args.overlap_score_files.split(",")
    systems = [
        # (name, is_finetune?)
        ("Seq2seq-D", False),
        ("Seq2seq-O", False),
        # ("Seq2seq-D+Finetune", True),
        # ("Seq2seq-O+Finetune", True),
        ("TranX-BPE-D-Greedy", False),
        ("TranX-BPE-D-Beam5", False),
        ("TranX-BPE-O-Greedy", False),
        ("TranX-BPE-O-Beam5", False),
        # ("TranX-t2t-D-Greedy+Finetune", True),
        # ("TranX-t2t-D-Beam5+Finetune", True),
        # ("TranX-t2t-O-Greedy+Finetune", True),
        # ("TranX-t2t-O-Beam5+Finetune", True),
        ("Seq2seq-D-Small", False),
        ("Tree2tree", False),
        ("Tree2tree-BPE", False),
        ("TranX-Small-D-Greedy", False),
    ]
    names = [name for name, _ in systems]
    hyp_paths = [
        "outputs_canon_new_decomp/test_default.hyp.orig",
        "outputs_canon_new_orig/test_default.hyp.orig",
        # "outputs_decomp_varname_finetune/test_default.hyp.orig",
        # "outputs_orig_varname_finetune/test_default.hyp.orig",
        get_tranx_path("decompiled", beam_size=1),
        get_tranx_path("original", beam_size=1),
        get_tranx_path("decompiled", beam_size=5),
        get_tranx_path("original", beam_size=5),
        # get_tranx_path("decompiled", beam_size=1, finetune=True),
        # get_tranx_path("original", beam_size=1, finetune=True),
        # get_tranx_path("decompiled", beam_size=5, finetune=True),
        # get_tranx_path("original", beam_size=5, finetune=True),
        "outputs_canon_new_decomp_small/test_default.hyp.orig",
        "../Tree2Tree-master/test_lr1e-3_139570.pkl.txt",
        "../Tree2Tree-master/test_bpe_64000.pkl.txt",
        get_tranx_path("decompiled",
                       beam_size=1,
                       train_file="tranx_data_small"),
    ]
    # overlap_paths = ["data_canonical/" + ("overlap_test.txt" if not is_finetune else "overlap_extra_test.txt")
    #                  for _, is_finetune in systems]
    scores = [float(x) for x in read_lines("test_overlap.txt", verbose=False)]
    overlap_paths = ["test_overlap.txt" for _ in systems]
    assert len(names) == len(hyp_paths)
    print("\n".join(f"{name}:   {path}"
                    for name, path in zip(names, hyp_paths)))
    hyp_data = {}
    for name, hyp_path in zip(names, hyp_paths):
        hyp_data[name] = list(
            read_lines(hyp_path, verbose=False, skip_empty=False))
    if len(overlap_paths) == 1:
        overlap_paths = overlap_paths * len(names)
    assert len(overlap_paths) == len(names)
    overlap_scores = {name: scores for name in names}

    assert len(src_data) == len(tgt_data) == len(additional_data)
    assert len(overlap_scores) == len(names) == len(hyp_data)
    assert all(
        len(data) == len(src_data) == len(scores)
        for data, scores in zip(hyp_data.values(), overlap_scores.values()))

    pickle_file = Path(args.pickle_file)
    with pickle_file.open("wb") as f:
        obj = (names, src_data, tgt_data, hyp_data, overlap_scores,
               additional_data)
        pickle.dump(obj, f)
    print(f"File written to {pickle_file}")

    # # Separate pickles for examples with overlap scores <= 0.5 or > 0.5
    # similar_indices = [idx for idx, score in enumerate(scores) if score > 0.5]
    # dissimilar_indices = [idx for idx, score in enumerate(scores) if score <= 0.5]
    # for indices, suffix in [(similar_indices, "similar"), (dissimilar_indices, "dissimilar")]:
    #     file = pickle_file.with_name(pickle_file.name + "_" + suffix)
    #     with file.open("wb") as f:
    #         obj = (
    #             names,
    #             [src_data[i] for i in indices],
    #             [tgt_data[i] for i in indices],
    #             {name: [data[i] for i in indices] for name, data in hyp_data.items()},
    #             {name: [data[i] for i in indices] for name, data in overlap_scores.items()},
    #             [additional_data[i] for i in indices],
    #         )
    #         pickle.dump(obj, f)
    #     print(f"File written to {file}")

    # NOTE: everything below this `return` is unreachable scratch code kept
    # for reference; `ref_data` is undefined (it would presumably be `tgt_data`).
    return

    bleu_scores = []
    edit_scores = []
    for src, ref, hyp in zip(tqdm(src_data, desc="Computing scores"), ref_data,
                             hyp_data):
        bleu4 = tx.evals.sentence_bleu([ref], hyp, max_order=4, smooth=True)
        bleu8 = tx.evals.sentence_bleu([ref], hyp, max_order=8, smooth=True)
        bleu_scores.append((bleu4, bleu8))
        edit_neu, edit_pos, edit_neg = batch_compute_edit_score(
            src,
            ref,
            hyp,
            reward_and_penalty=[(0.0, 0.0), (0.5, 0.0), (0.0, -0.5)])
        edit_scores.append((edit_neu, edit_pos, edit_neg))

    indices = sorted(range(len(src_data)), key=lambda i: bleu_scores[i][0])
    with open(args.output_file, "w") as f:
        for idx in tqdm(indices, desc="Writing output"):
            src = src_data[idx]
            ref = ref_data[idx]
            hyp = hyp_data[idx]
            overlap_score = scores[idx]  # `overlap_scores` is keyed by system name
            ref, hyp = color_match(ref, hyp)
            bleu4, bleu8 = bleu_scores[idx]
            f.write(
                colored(
                    f"Example {idx} (BLEU4 = {bleu4:.2f}, BLEU8 = {bleu8:.2f}), "
                    f"overlap with train = ", "yellow") +
                colored(f"{overlap_score:.3f}",
                        "red" if overlap_score > 0.8 else "yellow") + "\n")
            f.write(colored("Source:     ", "blue") + src + "\n")
            f.write(colored("Target:     ", "blue") + ref + "\n")
            f.write(colored("Prediction: ", "blue") + hyp + "\n")
            f.write("\n")

    bleu4 = tx.evals.corpus_bleu([[tgt] for tgt in tgt_data],
                                 hyp_data,
                                 max_order=4)
    bleu8 = tx.evals.corpus_bleu([[tgt] for tgt in tgt_data],
                                 hyp_data,
                                 max_order=8)
    print(f"BLEU4 = {bleu4:.2f}, BLEU8 = {bleu8:.2f}")
Example #8
0
def main() -> None:
    if not ghcc.utils.verify_docker_image(verbose=True):
        exit(1)

    args = Arguments()
    if args.n_procs == 0:
        # Only do this in the single-threaded case.
        flutes.register_ipython_excepthook()
    flutes.set_log_file(args.log_file)
    flutes.set_logging_level(args.logging_level, console=True, file=False)
    flutes.log("Running with arguments:\n" + args.to_string(),
               force_console=True)

    if os.path.exists(args.clone_folder):
        flutes.log(
            f"Removing contents of clone folder '{args.clone_folder}'...",
            "warning",
            force_console=True)
        ghcc.utils.run_docker_command(
            ["rm", "-rf", "/usr/src/*"],
            user=0,
            directory_mapping={args.clone_folder: "/usr/src"})

    flutes.log("Crawling starts...", "warning", force_console=True)
    db = ghcc.RepoDB()
    libraries: Set[str] = set()
    if args.record_libraries is not None and os.path.exists(
            args.record_libraries):
        with open(args.record_libraries, "r") as f:
            libraries = set(f.read().split())

    def flush_libraries():
        if args.record_libraries is not None:
            with open(args.record_libraries, "w") as f:
                f.write("\n".join(libraries))

    with flutes.safe_pool(args.n_procs, closing=[db, flush_libraries]) as pool:
        iterator = iter_repos(db, args.repo_list_file, args.max_repos)
        pipeline_fn: Callable[
            [RepoInfo], Optional[PipelineResult]] = functools.partial(
                clone_and_compile,
                clone_folder=args.clone_folder,
                binary_folder=args.binary_folder,
                archive_folder=args.archive_folder,
                recursive_clone=args.recursive_clone,
                clone_timeout=args.clone_timeout,
                compile_timeout=args.compile_timeout,
                force_reclone=args.force_reclone,
                force_recompile=args.force_recompile,
                docker_batch_compile=args.docker_batch_compile,
                max_archive_size=args.max_archive_size,
                compression_type=args.compression_type,
                record_libraries=(args.record_libraries is not None),
                record_metainfo=args.record_metainfo,
                gcc_override_flags=args.gcc_override_flags)
        repo_count = 0
        meta_info = MetaInfo()
        for result in pool.imap_unordered(pipeline_fn, iterator):
            repo_count += 1
            if repo_count % 100 == 0:
                flutes.log(f"Processed {repo_count} repositories",
                           force_console=True)
            if result is None:
                continue
            repo_owner, repo_name = result.repo_info.repo_owner, result.repo_info.repo_name
            if args.write_db:
                if result.clone_success is not None or result.repo_info.db_result is None:
                    # There's probably an inconsistency somewhere if we didn't clone while `db_result` is None.
                    # To prevent more errors, just add it to the DB.
                    repo_size = result.repo_size or -1  # a value of zero is probably also wrong
                    clone_success = result.clone_success if result.clone_success is not None else True
                    db.add_repo(repo_owner,
                                repo_name,
                                clone_success,
                                repo_size=repo_size)
                    flutes.log(f"Added {repo_owner}/{repo_name} to DB")
                if result.makefiles is not None:
                    update_result = db.update_makefile(
                        repo_owner,
                        repo_name,
                        result.makefiles,
                        ignore_length_mismatch=True)
                    if not update_result:
                        flutes.log(
                            f"Makefiles of {repo_owner}/{repo_name} not saved to DB due to Unicode encoding "
                            f"errors", "error")
            if result.libraries is not None:
                libraries.update(result.libraries)
                if repo_count % 10 == 0:  # flush every 10 repos
                    flush_libraries()

            if args.record_metainfo:
                meta_info.add_repo(result)
                if repo_count % 100 == 0:
                    flutes.log(repr(meta_info), force_console=True)

        flutes.log(repr(meta_info), force_console=True)
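
Example #8 flushes the accumulated library set to disk every few repositories so progress survives a crash. A minimal sketch of that periodic-checkpoint idiom:

from typing import Iterable, Set

def flush(path: str, items: Set[str]) -> None:
    with open(path, "w") as f:
        f.write("\n".join(sorted(items)))

def accumulate(results: Iterable[Set[str]], path: str, every: int = 10) -> Set[str]:
    seen: Set[str] = set()
    for count, libs in enumerate(results, 1):
        seen.update(libs)
        if count % every == 0:  # periodic flush
            flush(path, seen)
    flush(path, seen)  # final flush on completion
    return seen
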
Example #9
0
def test_register_ipython_excepthook() -> None:
    flutes.register_ipython_excepthook()
Example #10
0
def main() -> None:
    args = Args()
    if args.pdb:
        flutes.register_ipython_excepthook()
    if args.debug:
        print(
            colored(
                "Running in debug mode: no checkpoints or logs will be saved",
                "yellow"))

    with open(args.config_file) as f:
        config: Dict[str, Any] = cotra.utils.load_yaml(f)
    if args.extra_config is not None:
        import ast
        extra_config = ast.literal_eval(args.extra_config)
        cotra.utils.merge_dict(config, extra_config)
    # Do some validation before running time-consuming processes.
    assert os.path.exists(config["data"]["training_set"])
    assert all(
        os.path.exists(path) for path in config["data"]["valid_sets"].values())
    assert all(
        os.path.exists(path) for path in config["data"]["test_sets"].values())

    tx.run.make_deterministic(config["random_seed"])
    print(f"Random seed set to {config['random_seed']}")

    output_dir = Path(args.output_dir)
    if (not args.debug and output_dir.exists()
            and args.run_mode == "train" and not args.force):
        print(
            colored(
                f"Output folder '{str(output_dir)}' exists, use --force to overwrite."
            ))
        sys.exit(1)

    # Load data
    eval_datasets: Dict[str, Dict[str, cotra.CodeData]] = {}
    hparams = copy.deepcopy(config["data"]["hparams"])
    vocab = cotra.utils.Vocab.load(config["data"]["vocab_file"])
    train_dataset = None
    if args.run_mode == "train":
        train_dataset = cotra.CodeData(
            path=config["data"]["training_set"],
            vocab=vocab,
            hparams={
                **hparams,
                "shuffle": True,
                "curriculum": {
                    "enabled": args.curriculum
                },
                "verbose": config["data"]["verbose"],
                "num_parallel_calls": args.n_procs,
                # "lazy_strategy": "all",
                # "cache_strategy": "none",
                # "shuffle_buffer_size": 4096,
            })
    eval_splits: Dict[str, Dict[str, str]] = {
        "valid": config["data"]["valid_sets"],
        "test": config["data"]["test_sets"],
    }
    for split, paths in eval_splits.items():
        eval_datasets[split] = {
            f"{split}_{name}": cotra.CodeData(
                path=path,
                vocab=vocab,
                hparams={
                    **hparams,
                    "shuffle": False,
                    "curriculum": {
                        "enabled": False
                    },
                    "batch_size": config["training"]["test_batch_size"],
                    "max_dataset_size": 500 if split == "valid" else -1,
                    # Evaluation must use truncate mode -- no example in the test set should be discarded.
                    "length_filter_mode": "truncate",
                })
            for name, path in paths.items()
        }
    batching_strategy = cotra.CustomBatchingStrategy(
        config["training"]["max_batch_tokens"])
    print("Dataset initialized")

    # Create model and optimizer
    model = cotra.Seq2seq(vocab, hparams=config["model"])
    model = ModelWrapper(model,
                         beam_width=config["inference"]["beam_width"],
                         length_penalty=config["inference"]["length_penalty"])

    lr_config = config["lr_scheduler"]
    optim = torch.optim.Adam(model.parameters(),
                             lr=lr_config["lr"],
                             betas=(0.9, 0.997),
                             eps=1e-9)
    scheduler = cotra.utils.get_lr_scheduler(optim, lr_config)
    print("Model constructed")

    training_config = config["training"]
    test_output_path = output_dir / "test.output"
    # Only validate on the first validation split.
    valid_set = next(iter(eval_datasets["valid"].values()))
    actions_on_plateau = []
    if lr_config.get("lr_decay", 0.0) > 0.0:
        actions_on_plateau.append(tx.run.action.scale_lr(
            lr_config["lr_decay"]))
    executor = Executor(
        model=model,
        train_data=train_dataset,
        valid_data=valid_set,
        test_data=eval_datasets["test"],
        batching_strategy=batching_strategy,
        optimizer=optim,
        lr_scheduler=scheduler,
        grad_clip=training_config.get("grad_clip", None),
        log_destination=[
            sys.stdout, *([output_dir / "log.txt"] if not args.debug else [])
        ],
        log_every=cond.iteration(training_config["display_steps"]),
        validate_every=cond.iteration(training_config["eval_steps"]),
        stop_training_on=cond.iteration(training_config["max_train_steps"]),
        train_metrics=[
            ("loss", metric.RunningAverage(20)),  # average over 20 iterations
            ("lr", metric.LR(optim))
        ],
        log_format="{time} : Epoch {epoch:2d} @ {iteration:6d}it "
        "({progress}%, {speed}), lr = {lr:.3e}, loss = {loss:.3f}",
        valid_metrics=cotra.utils.WordPieceBLEU(
            vocab,
            decode=True,
            encoding="spm",
            sample_output_per=len(valid_set) // 10),
        plateau_condition=cond.validation(better=False),
        action_on_plateau=actions_on_plateau,
        test_metrics=[
            cotra.utils.FileBLEU(vocab, test_output_path, encoding="spm"),
            ("unofficial_bleu",
             cotra.utils.WordPieceBLEU(vocab, decode=True, encoding="spm"))
        ],
        valid_log_format="{time} : Epoch {epoch}, {split} BLEU = {BLEU:.3f}",
        test_progress_log_format=(
            "{time} : Evaluating on {split} ({progress}%, {speed}), "
            "unofficial BLEU = {unofficial_bleu:.2f}"),
        validate_mode='predict',
        checkpoint_dir=(args.output_dir if not args.debug else None),
        save_every=(cond.validation(better=True) if not args.debug else None),
        max_to_keep=5,
        show_live_progress=True,
    )

    executor.write_log(pprint.pformat(config))
    all_datasets = {
        "train": train_dataset,
        **{key: value
           for datasets in eval_datasets.values()
           for key, value in datasets.items()},
    }
    try:
        executor.write_log("Data size: " + repr({
            key: len(split)
            for key, split in all_datasets.items() if split is not None
        }))
    except TypeError:
        pass
    n_params = sum(param.numel() for param in model.parameters())
    executor.write_log(f"#Parameters: {str(n_params)}")

    if args.curriculum:

        @executor.on_event(cond.Event.Epoch, 'begin')
        def update_dataset_steps(exc: Executor):
            assert train_dataset is not None
            train_dataset.update_steps(exc.status["iteration"])
            exc._train_tracker.set_size(len(train_dataset))
            exc.write_log(
                f"Epoch {exc.status['epoch']}, competency updated to {train_dataset.competency * 100:6.2f}%"
            )

    executor.write_log(f"Begin running with {args.run_mode} mode")
    if args.run_mode == "train":
        if args.load_checkpoint:
            load_path = executor.load(path=args.checkpoint_path,
                                      allow_failure=True)
            # if load_path is not None:
            #     executor.test(eval_datasets["valid"])

        executor.train()
    else:
        executor.load(path=args.checkpoint_path, load_training_state=False)
        for name, dataset in eval_datasets[args.run_mode].items():
            executor.test({name: dataset})
            # Manually rename the test output file.
            os.rename(
                str(test_output_path) + ".hyp",
                output_dir / args.test_output_file.format(split=name))
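
Example #10 overlays an ast.literal_eval'd dict from the command line onto the YAML config. A sketch of a recursive merge; the actual behavior of cotra.utils.merge_dict may differ.

import ast
from typing import Any, Dict

def merge_dict(base: Dict[str, Any], extra: Dict[str, Any]) -> Dict[str, Any]:
    # Recursively overlay `extra` onto `base`, in place.
    for key, value in extra.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            merge_dict(base[key], value)
        else:
            base[key] = value
    return base

config = {"lr_scheduler": {"lr": 1e-3, "lr_decay": 0.1}}
extra = ast.literal_eval("{'lr_scheduler': {'lr': 5e-4}}")
merge_dict(config, extra)
assert config["lr_scheduler"] == {"lr": 5e-4, "lr_decay": 0.1}
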
Example #11
0
def main():
    args = Args()
    print(args.to_string())
    random.seed(args.random_seed)
    if args.pdb:
        flutes.register_ipython_excepthook()

    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    data = flutes.LazyList(read_data(args.data_dir))

    @flutes.cache(output_dir / "scores.pkl", name="competency scores")
    def compute_scores():
        # Gather word counts and compute "competency" scores for each example, for use in curriculum learning.
        with flutes.safe_pool(args.n_procs,
                              state_class=CountWordsState) as pool:
            for _ in pool.imap_unordered(
                    CountWordsState.count_words,
                    # Count words over the source (decompiled) sentences only.
                    map(lambda ex: ex.decompiled_code,
                        tqdm(data, desc="Counting words")),
                    chunksize=args.block_size):
                pass
            word_counter = Counter()
            for state in pool.get_states():
                word_counter.update(state.counter)
        total_words = sum(word_counter.values())
        word_scores = {
            w: -math.log(c / total_words)
            for w, c in word_counter.items()
        }
        with flutes.safe_pool(args.n_procs,
                              state_class=ComputeScoreState,
                              init_args=(word_scores, )) as pool:
            scores = list(
                pool.imap(ComputeScoreState.compute_score,
                          map(lambda ex: ex.decompiled_code,
                              tqdm(data, desc="Computing scores")),
                          chunksize=args.block_size))
        return scores

    scores = compute_scores()

    # Generate data splits.
    test_size = args.test_split_size or int(
        len(data) * args.test_split_portion)
    # Dev/test sets contain only repositories excluded from the training set.
    data_by_repo = defaultdict(list)
    for idx, ex in enumerate(data):
        data_by_repo[ex.repo].append(idx)
    repo_names = list(data_by_repo.keys())

    def create_excluded_split(target_size: int, max_repos: int, extra_train_portion: float, min_repo_size: int = 0) \
            -> Tuple[List[str], List[int], List[int]]:
        # Returns (chosen repo names, split indices, extra-train indices).
        filtered_repos = repo_names
        if min_repo_size > 0:
            filtered_repos = [
                repo for repo in filtered_repos
                if len(data_by_repo[repo]) >= min_repo_size
            ]
        while True:
            repo_count = random.randint(1, min(len(filtered_repos), max_repos))
            # Sample without replacement so a repository cannot be chosen twice.
            chosen_repos = random.sample(filtered_repos, k=repo_count)
            sample_size = sum(len(data_by_repo[name]) for name in chosen_repos)
            if 0.8 * target_size <= sample_size <= 1.1 * target_size:
                # Keep sampling until we get something with appropriate size.
                break
        extra_train_indices = []
        split_indices = []
        for name in chosen_repos:
            indices = data_by_repo[name].copy()
            random.shuffle(indices)
            split_size = int(len(indices) * extra_train_portion)
            extra_train_indices += indices[:split_size]
            split_indices += indices[split_size:]
        return chosen_repos, split_indices, extra_train_indices

    dev_repos, dev_split, extra_train_dev_split = create_excluded_split(
        test_size, args.max_test_repos, args.extra_train_portion)
    for repo_name in dev_repos:
        del data_by_repo[repo_name]
    test_repos, test_split, extra_train_test_split = create_excluded_split(
        test_size, args.max_test_repos, args.extra_train_portion)
    excluded_indices = set(dev_split + extra_train_dev_split + test_split +
                           extra_train_test_split)

    # Training set: all the remaining stuff.
    train_split = [
        idx for idx in range(len(data)) if idx not in excluded_indices
    ]
    # Sort training indices according to competency score.
    train_split.sort(key=lambda i: scores[i])
    extra_train_split = extra_train_dev_split + extra_train_test_split
    splits = {
        "train": train_split,
        "valid": dev_split,
        "test": test_split,
        "train_extra": extra_train_split,
    }
    with (output_dir / "split_indices.pkl").open("wb") as f:
        pickle.dump(splits, f)

    def write_files(folder_path: str, sentence_fn: Callable[[int],
                                                            Tuple[str, str]]):
        for key, indices in splits.items():
            with open(os.path.join(folder_path, f"{key}.txt"), "w") as f:
                for idx in tqdm(indices,
                                desc=f"Writing {key} set",
                                leave=False):
                    src, tgt = sentence_fn(idx)
                    ex = (src, tgt, data[idx].var_names, scores[idx],
                          data[idx].repo, data[idx].sha)
                    output = OutputData.encode(*ex)
                    # assert tuple(OutputData.decode(output)) == ex
                    f.write(output)
                    f.write("\n")
            print(f"{key.capitalize()} set written")

    write_files(output_dir, lambda idx: data[idx][:2])
    for key, names in [("valid", dev_repos), ("test", test_repos)]:
        with (output_dir / f"{key}_repos.txt").open("w") as f:
            f.write("\n".join(names))

    if not (output_dir / "vocab.model").exists():
        # Write out training text and train SentencePiece model.
        train_text_path = output_dir / "train_text.txt"
        with train_text_path.open("w") as f:
            for idx in tqdm(train_split, desc="Writing training text"):
                src_tokens = data[idx].decompiled_code.split(TOKEN_SEP)
                new_src_tokens = []
                for token in src_tokens:
                    if token in data[idx].var_names:
                        var1, var2 = data[idx].var_names[token]
                        new_src_tokens += [var1, var2]
                    else:
                        new_src_tokens.append(token)
                f.write(" ".join(new_src_tokens) + "\n")
                f.write(data[idx].original_code.replace(TOKEN_SEP, " ") + "\n")
        spm_train_args = {
            "input": train_text_path,
            "model_prefix": output_dir / "vocab",
            "vocab_size": args.vocab_size,
        }
        spm.SentencePieceTrainer.Train(" ".join(
            f"--{name}={str(value)}"
            for name, value in spm_train_args.items()))

    if args.encode_spm:
        # Encode all sentences with the trained SP model.
        with flutes.safe_pool(args.n_procs,
                              state_class=EncodeSPMState,
                              init_args=(output_dir /
                                         "vocab.model", )) as pool:
            processed_code = list(
                pool.imap(EncodeSPMState.encode_spm,
                          map(
                              lambda ex:
                              (ex.decompiled_code, ex.original_code),
                              tqdm(data, desc="Encoding with SPM")),
                          chunksize=args.block_size))

        write_files(output_dir / "tokenized", lambda idx: processed_code[idx])