Code Example #1
File: test_multiproc.py Project: huzecong/flutes
def test_safe_pool() -> None:
    seq = list(range(10000))
    target = list(map(sqr, seq))  # sequential
    with mp.Pool(1) as pool:
        pool_type = type(pool)

    def file_obj():
        nonlocal call_count
        call_count += 1

    call_count = 0
    with flutes.safe_pool(0, closing=[file_obj]) as pool:
        check_iterator(pool.imap(sqr, seq), target)
    assert not isinstance(pool, pool_type)
    assert call_count == 1

    file_obj = NonCallableMagicMock()
    file_obj.mock_add_spec(["close"])
    with flutes.safe_pool(2, closing=[file_obj],
                          suppress_exceptions=True) as pool:
        result = list(pool.imap(sqr, seq))
        raise ValueError  # should swallow exceptions
    assert isinstance(pool, pool_type)
    assert result == target
    file_obj.close.assert_called_once()

    with pytest.raises(KeyboardInterrupt):
        with flutes.safe_pool(2, closing=[file_obj],
                              suppress_exceptions=True) as pool:
            raise KeyboardInterrupt

    manager = ContextManagerMock()
    with flutes.safe_pool(2, closing=[manager]) as pool:
        raise ValueError
    assert manager.state == 2
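
The test above exercises the `closing=` argument of `flutes.safe_pool`: entries may be plain callables (invoked on exit), objects exposing `close()`, or context managers, and they are cleaned up even when the block raises. A minimal sketch of that pattern, assuming only that flutes is installed and using a hypothetical `square` worker and `cleanup` callback:

import flutes

def square(x: int) -> int:  # hypothetical worker; module-level so it pickles
    return x * x

def cleanup() -> None:  # plain callable passed via `closing=`; invoked when the pool exits
    print("pool torn down")

if __name__ == "__main__":
    with flutes.safe_pool(2, closing=[cleanup]) as pool:
        results = list(pool.imap(square, range(10)))
    # cleanup() has run by this point, even if the block above had raised
    print(results)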
Code Example #2
def compute_scores():
    # Gather word counts and compute "competency" scores for each example, for use in curriculum learning.
    with flutes.safe_pool(args.n_procs,
                          state_class=CountWordsState) as pool:
        for _ in pool.imap_unordered(
                CountWordsState.
                count_words,  # use only the source sentence
                map(lambda ex: ex.decompiled_code,
                    tqdm(data, desc="Counting words")),
                chunksize=args.block_size):
            pass
        word_counter = Counter()
        for state in pool.get_states():
            word_counter.update(state.counter)
    total_words = sum(word_counter.values())
    word_scores = {
        w: -math.log(c / total_words)
        for w, c in word_counter.items()
    }
    with flutes.safe_pool(args.n_procs,
                          state_class=ComputeScoreState,
                          init_args=(word_scores, )) as pool:
        scores = list(
            pool.imap(ComputeScoreState.compute_score,
                      map(lambda ex: ex.decompiled_code,
                          tqdm(data, desc="Computing scores")),
                      chunksize=args.block_size))
    return scores
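
`CountWordsState` and `ComputeScoreState` are defined elsewhere in the project and are not part of this excerpt. As a rough illustration of how such a class plugs into safe_pool's `state_class` / `init_args` / `get_states()` machinery, here is a hypothetical word-counting state class (the name and internals are illustrative only, and flutes' `PoolState` base class is assumed):

from collections import Counter

import flutes

class WordCountState(flutes.PoolState):  # hypothetical; assumes flutes.PoolState as the base class
    def __init__(self) -> None:
        # One instance is created per worker process; `init_args` would be passed here.
        self.counter: Counter = Counter()

    def count_words(self, sentence: str) -> None:
        # Submitted as the *unbound* method WordCountState.count_words;
        # each worker invokes it on its own instance and mutates local state.
        self.counter.update(sentence.split())

if __name__ == "__main__":
    sentences = ["a b a", "b c", "a c c"]
    with flutes.safe_pool(2, state_class=WordCountState) as pool:
        for _ in pool.imap_unordered(WordCountState.count_words, sentences):
            pass
        total = Counter()
        for state in pool.get_states():  # collect each worker's state instance
            total.update(state.counter)
    print(total)  # aggregated counts from all workers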
Code Example #3
File: test_multiproc.py Project: huzecong/flutes
def test_gather() -> None:
    n = 10000
    intervals = list(range(0, n + 1, 1000))
    answer = set(range(n))
    for n_procs in [0, 2]:
        with flutes.safe_pool(n_procs) as pool:
            assert set(pool.gather(gather_fn, zip(intervals,
                                                  intervals[1:]))) == answer
        with flutes.safe_pool(n_procs, state_class=PoolState2) as pool:
            assert set(
                pool.gather(PoolState2.gather_fn,
                            zip(intervals, intervals[1:]))) == answer
Code Example #4
def main():
    args = Args()
    print(args.to_string())
    random.seed(args.random_seed)
    if args.pdb:
        flutes.register_ipython_excepthook()

    splits = {
        "train_extra": "train_extra/",
        "valid": "dev/",
        "test": "test/",
        "train": ".",
    }

    tranx_data_dir = Path(args.tranx_data_dir)
    text_data_dir = Path(args.text_data_dir)
    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    with flutes.safe_pool(args.n_procs, state_class=State) as pool:
        for key, path in splits.items():
            pool.broadcast(State.clear)  # reset every worker's state before processing this split
            files = [
                file for file in sorted(Path(tranx_data_dir / path).iterdir())
                if file.name.startswith("data") and file.suffix == ".pkl"
            ]
            for _ in tqdm(pool.imap_unordered(State.process_data, files),
                          total=len(files)):
                pass
            canonical_tgt_map: Dict[str, str] = {}
            total_size = 0
            for state_map in pool.get_states():
                total_size += len(state_map.canonical_map)
                canonical_tgt_map.update(state_map.canonical_map)
            assert total_size == len(canonical_tgt_map)
            print(f"{key.capitalize()} set processed")
            in_path = text_data_dir / f"{key}.txt"
            out_path = output_dir / f"{key}.txt"
            not_found = set()
            with in_path.open("r") as fin, out_path.open("w") as fout:
                progress = tqdm(fin, total=total_size)
                for line in progress:
                    if not line: continue
                    example = InputData.decode(line)
                    encoded_output = canonical_tgt_map.get(
                        example.original_code, None)
                    if encoded_output is None:
                        not_found.add(example.original_code)
                        progress.set_postfix(not_found=len(not_found))
                    else:
                        del canonical_tgt_map[example.original_code]
                        fout.write(encoded_output)
                        fout.write("\n")
                print(
                    f"{len(not_found)} not found, {len(canonical_tgt_map)} leftover"
                )
                for encoded_output in canonical_tgt_map.values():
                    fout.write(encoded_output)
                    fout.write("\n")
                print(f"{key.capitalize()} set written to {out_path}")
Code Example #5
File: test_multiproc.py Project: huzecong/flutes
def test_stateful_pool() -> None:
    large_dict = {str(i): i for i in range(100000)}
    seq = list(map(str, range(100000)))
    target = sum(map(lambda x: int(x) + 1, seq))  # sequential

    for n_procs in [0, 2]:
        with flutes.safe_pool(n_procs,
                              state_class=PoolState,
                              init_args=(large_dict, )) as pool:
            result = sum(
                pool.imap_unordered(PoolState.convert, seq, chunksize=1000))

            # See, if you had a type checker, you wouldn't be making these mistakes.
            with pytest.raises(ValueError,
                               match="Bound methods of the pool state class"):
                _ = sum(
                    pool.imap_unordered(PoolState({}).convert,
                                        seq,
                                        chunksize=1000))
            with pytest.raises(
                    ValueError,
                    match="Only unbound methods of the pool state class"):
                _ = sum(
                    pool.imap_unordered(
                        PoolState2.generate, seq,
                        chunksize=1000))  # type: ignore[arg-type]
        assert result == target
Code Example #6
File: test_multiproc.py Project: huzecong/flutes
def test_pool_methods() -> None:
    seq = list(range(10000))
    args = (2, )
    kwds = {"coef2": 3}
    target = [sqr(x, *args, **kwds) for x in seq]  # sequential
    for n_procs in [0, 2]:
        for state_class in [PoolState, None]:
            with flutes.safe_pool(n_procs,
                                  state_class=state_class,
                                  init_args=(None, )) as pool:
                assert pool.map(sqr, seq, args=args, kwds=kwds) == target
                check_iterator(pool.imap(sqr, seq, args=args, kwds=kwds),
                               target)
                assert sorted(
                    pool.imap_unordered(sqr, seq, args=args,
                                        kwds=kwds)) == target
                assert pool.starmap(mul, zip(seq, seq), args=args,
                                    kwds=kwds) == target
                assert pool.map_async(sqr, seq, args=args,
                                      kwds=kwds).get() == target
                assert pool.starmap_async(mul,
                                          zip(seq, seq),
                                          args=args,
                                          kwds=kwds).get() == target
                assert pool.apply(sqr, (10, 2), kwds=kwds) == 100 * 2 * 3
                assert pool.apply_async(sqr, (10, 2),
                                        kwds=kwds).get() == 100 * 2 * 3
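
The test above covers the extra `args`/`kwds` parameters that flutes' pool methods accept and forward to every call, on top of the standard `multiprocessing.Pool` signatures. A minimal sketch, assuming flutes is installed and using a hypothetical `scale` function (passing 0 processes runs the work in-process, which the tests above use for coverage):

import flutes

def scale(x: int, factor: int = 1, offset: int = 0) -> int:  # hypothetical function
    return x * factor + offset

if __name__ == "__main__":
    with flutes.safe_pool(0) as pool:  # n_procs=0: sequential dummy pool, handy for debugging
        out = pool.map(scale, range(5), args=(10,), kwds={"offset": 3})
    print(out)  # [3, 13, 23, 33, 43]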
Code Example #7
def test_stateful_pool_get_state() -> None:
    for n_procs in [0, 2]:
        with flutes.safe_pool(n_procs, state_class=PoolState2) as pool:
            intervals = list(range(0, 100 + 1, 5))
            pool.starmap(PoolState2.generate, zip(intervals, intervals[1:]))
            states = pool.get_states()
            assert sorted(x for state in states
                          for x in state.numbers) == list(range(100))
Code Example #8
def test_ProgressBarManager() -> None:
    for proc in [0, 2]:
        # Test multiprocessing in `proc = 2`
        # Test coverage in `proc = 0`
        manager = flutes.ProgressBarManager()
        with flutes.safe_pool(proc, closing=[manager]) as pool:
            fn = functools.partial(progress_bar_fn, bar=manager.proxy)
            pool.map(fn, range(10))
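
The helper `progress_bar_fn` is defined elsewhere in the test module; it receives the manager's picklable proxy via `functools.partial` and drives progress bars through it, the same pattern Code Example #15 uses. A hypothetical sketch of such a worker, assuming only the `proxy.new(total=...)` and `update()` calls that appear in these examples:

import functools

import flutes

def work(idx: int, *, bar) -> None:  # hypothetical worker; `bar` is manager.proxy
    pbar = bar.new(total=100)  # open a progress bar from inside the worker
    for _ in range(100):
        pbar.update(1)

if __name__ == "__main__":
    manager = flutes.ProgressBarManager()
    with flutes.safe_pool(2, closing=[manager]) as pool:  # manager is closed together with the pool
        pool.map(functools.partial(work, bar=manager.proxy), range(4))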
Code Example #9
def main(args):

    random.seed(args.seed)

    if (not args.output_dir.exists()) or args.overwrite:
        args.output_dir.mkdir(exist_ok=True, parents=True)
    else:
        print(f"Directory {args.output_dir} already exists.")
        sys.exit(0)

    tr_indices, vl_indices, ts_indices = make_splits_indices(
        100, args.split_ratio)
    with (args.output_dir / "file_map.txt").open("w") as f:
        print("train", file=f)
        for i in tr_indices:
            print(f"{i} ", file=f, end="")
        print(file=f)
        print("valid", file=f)
        for i in vl_indices:
            print(f"{i} ", file=f, end="")
        print(file=f)
        print("test")
        for i in ts_indices:
            print(f"{i} ", file=f, end="")
        print(file=f)

    # in/out file pairs
    files = ([(
        (args.dump_dir / f"{i}.pkl"),
        (args.output_dir / f"train_{new_idx:02}.jsonl"),
    ) for new_idx, i in enumerate(tr_indices)] + [(
        (args.dump_dir / f"{i}.pkl"),
        (args.output_dir / f"valid_{new_idx:02}.jsonl"),
    ) for new_idx, i in enumerate(vl_indices)] + [(
        (args.dump_dir / f"{i}.pkl"),
        (args.output_dir / f"test_{new_idx:02}.jsonl"),
    ) for new_idx, i in enumerate(ts_indices)])

    # files = sorted(
    #     list((args.dump_dir).glob("*.pkl")), key=lambda x: int(x.with_suffix("").name)
    # )

    total = {}
    with flutes.safe_pool(processes=args.njobs, state_class=Processor) as pool:
        for idx, _ in enumerate(
                pool.imap_unordered(Processor.process_pkl, files,
                                    chunksize=1)):
            flutes.log(f"Processed {(idx + 1)} files")

        states = pool.get_states()
        for state in states:
            total.update(state.results)
    print(
        f"Train: {sum(v for k, v in total.items() if k.startswith('train'))}")
    print(
        f"Valid: {sum(v for k, v in total.items() if k.startswith('valid'))}")
    print(f"Test:  {sum(v for k, v in total.items() if k.startswith('test'))}")
Code Example #10
File: test_multiproc.py Project: huzecong/flutes
def test_stateful_pool_get_state() -> None:
    for n_procs in [0, 2]:
        with flutes.safe_pool(n_procs, state_class=PoolState2) as pool:
            intervals = list(range(0, 100 + 1, 5))
            pool.starmap(PoolState2.generate,
                         zip(intervals, intervals[1:]),
                         args=(1, 2))  # dummy args
            states = pool.get_states()
        assert sorted(itertools.chain.from_iterable(states)) == list(
            range(100))  # type: ignore[arg-type]
Code Example #11
def test_safe_pool() -> None:
    seq = list(range(10000))
    target = list(map(sqr, seq))  # sequential
    with mp.Pool(1) as pool:
        pool_type = type(pool)

    file_obj = MagicMock()
    with flutes.safe_pool(0, closing=[file_obj]) as pool:
        assert type(pool) is not pool_type
        check_iterator(pool.imap(sqr, seq), target)
    file_obj.assert_called_once()

    file_obj = NonCallableMagicMock()
    file_obj.mock_add_spec(["close"])
    with flutes.safe_pool(2, closing=[file_obj]) as pool:
        assert type(pool) is pool_type
        check_iterator(pool.imap(sqr, seq), target)
        raise ValueError  # should swallow exceptions
    file_obj.close.assert_called_once()
Code Example #12
def main() -> None:
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} [file]")
        sys.exit(1)

    path = sys.argv[1]
    with flutes.work_in_progress("Read file"):
        with open(path) as f:
            sentences = []
            for line in f:
                sentences.append(line)
                if len(sentences) >= 100000:
                    break

    with flutes.work_in_progress("Parallel"):
        with flutes.safe_pool(processes=4,
                              state_class=WordCounter) as pool_stateful:
            for _ in pool_stateful.imap_unordered(WordCounter.count_words,
                                                  sentences,
                                                  chunksize=1000):
                pass
        parallel_word_counter: CounterT[str] = Counter()
        with flutes.work_in_progress("Get states"):
            states = pool_stateful.get_states()
        for state in states:
            parallel_word_counter.update(state.word_cnt)

    with flutes.work_in_progress("Naive parallel"):
        naive_word_counter: CounterT[str] = Counter()
        data_chunks = flutes.chunk(1000, sentences)
        with flutes.safe_pool(processes=4) as pool:
            for counter in pool.imap_unordered(count_words, data_chunks):
                naive_word_counter.update(counter)

    with flutes.work_in_progress("Sequential"):
        seq_word_counter: CounterT[str] = Counter()
        for sent in sentences:
            seq_word_counter.update(word.lower() for word in sent.split())

    assert seq_word_counter == naive_word_counter == parallel_word_counter
Code Example #13
File: merge.py Project: muggin/disentangled-sum
def main(args) -> None:

    files = [(args.output_dir, args.input_dir / f"{i}.pkl")
             for i in range(0, 10000, 100)]

    total_map = {}
    with flutes.work_in_progress("Parallel"):
        with flutes.safe_pool(processes=args.njobs,
                              state_class=Worker) as pool_stateful:
            for idx, _ in enumerate(
                    pool_stateful.imap_unordered(Worker.merge,
                                                 files,
                                                 chunksize=1)):
                flutes.log(f"Processed {(idx + 1)} files")
Code Example #14
File: test_multiproc.py Project: huzecong/flutes
def test_ProgressBarManager() -> None:
    for verbose in [False, True]:
        for proc in [0, 2]:
            # Test multiprocessing in `proc = 2`
            # Test coverage in `proc = 0`
            manager = flutes.ProgressBarManager(verbose=verbose)
            with flutes.safe_pool(proc, closing=[manager]) as pool:
                fn = functools.partial(progress_bar_fn, bar=manager.proxy)
                pool.map(fn, range(10))
                fn = functools.partial(file_progress_bar_fn, bar=manager.proxy)
                pool.map(fn, range(4))
            flutes.log(
                f"This should still show up: verbose={verbose}, proc={proc}",
                force_console=True)
Code Example #15
def main():
    args = Args()
    print(args.to_string())
    if args.pdb:
        flutes.register_ipython_excepthook()
    random.seed(args.random_seed)

    files = [
        file for file in Path(args.data_dir).iterdir() if file.suffix == ".pkl"
    ]
    manager = flutes.ProgressBarManager()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with flutes.safe_pool(args.n_procs,
                          state_class=PoolState,
                          init_args=(manager.proxy, )) as pool:

        def _wrap_iter() -> Iterator[str]:
            progress = manager.proxy.new(total=len(files))
            for tokens in pool.imap_unordered(PoolState.collect_tokens, files):
                yield from tokens
                progress.update(1)
            ident_tokens = Counter()
            for id_counts in pool.get_states():
                ident_tokens.update(id_counts)
            progress = manager.proxy.new(total=sum(ident_tokens.values()))
            for token, count in ident_tokens.items():
                yield from [token] * count
                progress.update(count)

        sampled_tokens = sample(args.sample_size, _wrap_iter())

    with (output_dir / "tokens.txt").open("w") as f:
        random.shuffle(sampled_tokens)
        for token in sampled_tokens:
            f.write(token)
            f.write("\n")

    spm_train_args = {
        "input": output_dir / "tokens.txt",
        "model_prefix": output_dir / "vocab",
        "vocab_size": args.vocab_size,
        "split_by_whitespace": 0,  # false
        "remove_extra_whitespaces": 0,  # false
        # "input_sentence_size": 10 ** 8,
    }
    spm.SentencePieceTrainer.Train(" ".join(
        f"--{name}={str(value)}" for name, value in spm_train_args.items()))
Code Example #16
def test_pool_methods() -> None:
    seq = list(range(10000))
    target = list(map(sqr, seq))  # sequential
    for n_procs in [0, 2]:
        for state_class in [PoolState, None]:
            with flutes.safe_pool(n_procs,
                                  state_class=state_class,
                                  init_args=(None, )) as pool:
                check_iterator(pool.map(sqr, seq), target)
                check_iterator(pool.imap(sqr, seq), target)
                check_iterator(sorted(pool.imap_unordered(sqr, seq)), target)
                check_iterator(pool.starmap(operator.mul, zip(seq, seq)),
                               target)
                check_iterator(pool.map_async(sqr, seq).get(), target)
                check_iterator(
                    pool.starmap_async(operator.mul, zip(seq, seq)).get(),
                    target)
                assert pool.apply(sqr, (10, )) == 100
                assert pool.apply_async(sqr, (10, )).get() == 100
Code Example #17
File: run_decompiler.py Project: xcode2010/ghcc
def main() -> None:
    if args.n_procs == 0:
        # Only do this on the single-threaded case.
        flutes.register_ipython_excepthook()
    flutes.log(f"Running with {args.n_procs} worker processes", "warning")

    # Check for/create output directories
    make_directory(args.output_dir)

    # Use RAM-backed memory for tmp if available
    if os.path.exists('/dev/shm'):
        tempfile.tempdir = '/dev/shm'

    flutes.set_log_file(args.log_file)
    write_pseudo_registry()

    # Obtain a list of all binaries
    binaries = get_binary_mapping(args.binary_mapping_cache_file)

    flutes.log(f"{len(binaries)} binaries to process.")
    file_count = 0
    db = ghcc.BinaryDB()

    with flutes.safe_pool(args.n_procs, closing=[db]) as pool:
        decompile_fn: Callable[[BinaryInfo],
                               DecompilationResult] = functools.partial(
                                   decompile,
                                   output_dir=args.output_dir,
                                   binary_dir=args.binaries_dir,
                                   timeout=args.timeout)
        for result in pool.imap_unordered(decompile_fn,
                                          iter_binaries(db, binaries)):
            file_count += 1
            if result is not None:
                db.add_binary(result.info["repo_owner"],
                              result.info["repo_name"], result.hash,
                              result.status is DecompilationStatus.Success)
            if file_count % 100 == 0:
                flutes.log(f"Processed {file_count} binaries",
                           force_console=True)
Code Example #18
def main(args) -> None:
    Proc = LegacyFilter if args.legacy else Filter

    if args.target == "paper":
        processing = Proc.filter_ids_complete  # gathers pid, in/out cite pids
    elif args.target == "citation":
        processing = Proc.filter_ids_text  # gathers pids

    if args.valid_citations is not None and args.valid_citations.exists():
        with args.valid_citations.open("rb") as f:
            d = pickle.load(f)
        dict_valid_citations = {k: True for _, pids in d.items() for k in pids}
        del d
    else:
        dict_valid_citations = {}

    files = [
        (f, dict_valid_citations, args.min_cite, args.max_cite, args.seed)
        for f in list(args.input_dir.glob("*"))
    ]

    with flutes.work_in_progress("Parallel"):
        total_results = defaultdict(list)
        with flutes.safe_pool(processes=args.njobs, state_class=Proc) as pool_stateful:
            for idx, _ in enumerate(
                pool_stateful.imap_unordered(processing, files, chunksize=10)
            ):
                flutes.log(f"Processed {(idx + 1)} files")
            with flutes.work_in_progress("Get states"):
                states = pool_stateful.get_states()
            for state in states:
                total_results.update(state.results)

    with args.output_file.open("wb") as f:
        # Dict[batchnum, List[obj]]
        pickle.dump(total_results, f)
Code Example #19
File: match_functions.py Project: xcode2010/ghcc
def main() -> None:
    if not ghcc.utils.verify_docker_image(verbose=True):
        exit(1)

    sys.setrecursionlimit(10000)
    args = Arguments()
    if args.pdb:
        flutes.register_ipython_excepthook()
        if args.n_procs == 0:
            globals()['match_functions'] = match_functions.__wrapped__

    if not args.verbose:
        flutes.set_logging_level("quiet", console=True, file=False)
    flutes.set_log_file(args.log_file)
    flutes.log("Running with arguments:\n" + args.to_string(),
               force_console=True)

    if os.path.exists(args.temp_dir):
        flutes.log(
            f"Removing contents of temporary folder '{args.temp_dir}'...",
            "warning",
            force_console=True)
        ghcc.utils.run_docker_command(
            ["rm", "-rf", "/usr/src/*"],
            user=0,
            directory_mapping={args.temp_dir: "/usr/src"})

    db = ghcc.MatchFuncDB()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    manager = flutes.ProgressBarManager(
        verbose=args.show_progress,
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}{postfix}]")
    with flutes.safe_pool(args.n_procs, closing=[db, manager]) as pool:
        iterator, stats = iter_repos(
            db,
            args.max_repos,
            skip_to=args.skip_to,
            cache_path=args.repo_binary_info_cache_path,
            force_reprocess=args.force_reprocess)
        match_fn: Callable[[RepoInfo], Result] = functools.partial(
            match_functions,
            archive_folder=args.archive_dir,
            temp_folder=args.temp_dir,
            decompile_folder=args.decompile_dir,
            use_fake_libc_headers=args.use_fake_libc_headers,
            preprocess_timeout=args.preprocess_timeout,
            progress_bar=manager.proxy)

        repo_count = stats.repo_count
        func_count = stats.func_count
        func_without_ast_count = stats.func_without_ast_count
        for result in pool.imap_unordered(match_fn, iterator):
            if result is None:
                # Exception occurred.
                if args.exit_on_exception:
                    flutes.log(
                        f"Exception occurred, exiting because 'exit_on_exception' is True",
                        "warning")
                    break
                continue

            # Write the matched functions to disk.
            result: Result  # type: ignore
            repo_dir = output_dir / result.repo_owner / result.repo_name
            repo_dir.mkdir(parents=True, exist_ok=True)
            with (repo_dir / "matched_funcs.jsonl").open("w") as f:
                for matched_func in result.matched_functions:
                    f.write(
                        json.dumps(matched_func._asdict(),
                                   separators=(',', ':')) + "\n")
            for sha, code in result.preprocessed_original_code.items():
                with (repo_dir / f"{sha}.c").open("w") as f:
                    pos = code.rfind(ghcc.parse.FAKE_LIBC_END_LINE)
                    if pos != -1:
                        code = code[(pos +
                                     len(ghcc.parse.FAKE_LIBC_END_LINE)):]
                    f.write(code)

            if args.write_db:
                db.add_repo(
                    result.repo_owner,
                    result.repo_name,
                    files_found=result.files_found,
                    funcs_found=result.functions_found,
                    funcs_matched=len(result.matched_functions),
                    funcs_matched_without_ast=result.funcs_without_asts)

            repo_count += 1
            func_count += len(result.matched_functions)
            func_without_ast_count += result.funcs_without_asts
            if repo_count % 100 == 0:
                flutes.log(
                    f"Processed {repo_count} repositories, {func_count} functions matched "
                    f"({func_without_ast_count} w/o AST)",
                    force_console=True)
Code Example #20
def main() -> None:
    if not ghcc.utils.verify_docker_image(verbose=True):
        exit(1)

    args = Arguments()
    if args.n_procs == 0:
        # Only do this on the single-threaded case.
        flutes.register_ipython_excepthook()
    flutes.set_log_file(args.log_file)
    flutes.set_logging_level(args.logging_level, console=True, file=False)
    flutes.log("Running with arguments:\n" + args.to_string(),
               force_console=True)

    if os.path.exists(args.clone_folder):
        flutes.log(
            f"Removing contents of clone folder '{args.clone_folder}'...",
            "warning",
            force_console=True)
        ghcc.utils.run_docker_command(
            ["rm", "-rf", "/usr/src/*"],
            user=0,
            directory_mapping={args.clone_folder: "/usr/src"})

    flutes.log("Crawling starts...", "warning", force_console=True)
    db = ghcc.RepoDB()
    libraries: Set[str] = set()
    if args.record_libraries is not None and os.path.exists(
            args.record_libraries):
        with open(args.record_libraries, "r") as f:
            libraries = set(f.read().split())

    def flush_libraries():
        if args.record_libraries is not None:
            with open(args.record_libraries, "w") as f:
                f.write("\n".join(libraries))

    with flutes.safe_pool(args.n_procs, closing=[db, flush_libraries]) as pool:
        iterator = iter_repos(db, args.repo_list_file, args.max_repos)
        pipeline_fn: Callable[
            [RepoInfo], Optional[PipelineResult]] = functools.partial(
                clone_and_compile,
                clone_folder=args.clone_folder,
                binary_folder=args.binary_folder,
                archive_folder=args.archive_folder,
                recursive_clone=args.recursive_clone,
                clone_timeout=args.clone_timeout,
                compile_timeout=args.compile_timeout,
                force_reclone=args.force_reclone,
                force_recompile=args.force_recompile,
                docker_batch_compile=args.docker_batch_compile,
                max_archive_size=args.max_archive_size,
                compression_type=args.compression_type,
                record_libraries=(args.record_libraries is not None),
                record_metainfo=args.record_metainfo,
                gcc_override_flags=args.gcc_override_flags)
        repo_count = 0
        meta_info = MetaInfo()
        for result in pool.imap_unordered(pipeline_fn, iterator):
            repo_count += 1
            if repo_count % 100 == 0:
                flutes.log(f"Processed {repo_count} repositories",
                           force_console=True)
            if result is None:
                continue
            repo_owner, repo_name = result.repo_info.repo_owner, result.repo_info.repo_name
            if args.write_db:
                if result.clone_success is not None or result.repo_info.db_result is None:
                    # There's probably an inconsistency somewhere if we didn't clone while `db_result` is None.
                    # To prevent more errors, just add it to the DB.
                    repo_size = result.repo_size or -1  # a value of zero is probably also wrong
                    clone_success = result.clone_success if result.clone_success is not None else True
                    db.add_repo(repo_owner,
                                repo_name,
                                clone_success,
                                repo_size=repo_size)
                    flutes.log(f"Added {repo_owner}/{repo_name} to DB")
                if result.makefiles is not None:
                    update_result = db.update_makefile(
                        repo_owner,
                        repo_name,
                        result.makefiles,
                        ignore_length_mismatch=True)
                    if not update_result:
                        flutes.log(
                            f"Makefiles of {repo_owner}/{repo_name} not saved to DB due to Unicode encoding "
                            f"errors", "error")
            if result.libraries is not None:
                libraries.update(result.libraries)
                if repo_count % 10 == 0:  # flush every 10 repos
                    flush_libraries()

            if args.record_metainfo:
                meta_info.add_repo(result)
                if repo_count % 100 == 0:
                    flutes.log(repr(meta_info), force_console=True)

        flutes.log(repr(meta_info), force_console=True)
Code Example #21
def main(args) -> None:

    pids_with_text = set()
    if args.valid_pids and args.valid_pids.exists():
        pids_with_text = set()
        with args.valid_pids.open("rb") as f:
            data = pickle.load(f)
            buf = set()
            if args.mode == "citation":
                for k, d in tqdm(enumerate(data.values()),
                                 ncols=88,
                                 ascii=True):
                    buf = buf.union(
                        set([pid for i in d for pid in i[1] + i[2]]))
                    if k % 500 == 0:
                        pids_with_text = pids_with_text.union(buf)
                        buf = set()
            elif args.mode == "paper":
                for k, d in tqdm(data.items(), ncols=88, ascii=True):
                    buf = buf.union(set([i[0] for i in d]))
                    if k % 500 == 0:
                        pids_with_text = pids_with_text.union(buf)
                        buf = set()

            # remaining one
            pids_with_text = pids_with_text.union(buf)

    flutes.log(f"# of valid pids to consider: {len(pids_with_text)}")

    if args.legacy:
        # glob takes more time than this?
        files = ((args.input_dir / f"{i}.jsonl.gz", pids_with_text)
                 for i in range(10000))
        Proc = LegacyFilter
    else:
        files = ((args.input_dir / f"pdf_parses_{i}.jsonl.gz", pids_with_text)
                 for i in range(100))
        Proc = Filter

    with flutes.work_in_progress("Parallel"):
        total_map = {}
        with flutes.safe_pool(processes=args.njobs,
                              state_class=Proc) as pool_stateful:
            for idx, _ in enumerate(
                    pool_stateful.imap_unordered(Proc.make_map,
                                                 files,
                                                 chunksize=10)):
                flutes.log(f"Processed {(idx + 1)} files")

            with flutes.work_in_progress("Get states"):
                states = pool_stateful.get_states()
            for state in states:
                # TODO: Incorporate incite number
                total_map.update(state.results)

        flutes.log(f"Total map size: {len(total_map)}")

    with args.output.open("w") as f:
        for k, v in total_map.items():
            print(k, v[0], v[1], sep="\t", file=f)

        flutes.log(f"Dumped to {args.output}")
Code Example #22
def main():
    # flutes.register_ipython_excepthook()

    sys.setrecursionlimit(50000)
    args = Arguments()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    flutes.log("Dataset creation start")

    db = ghcc.MatchFuncDB()
    original_code_set: Set[str] = set()
    n_duplicate = 0
    n_examples = 0
    manager = mp.Manager()
    example_queue: 'mp.Queue[QueueElem]' = manager.Queue(args.queue_size)
    with flutes.safe_pool(args.n_procs, closing=[db]) as pool:
        repos = [
            RepoInfo(entry['repo_owner'], entry['repo_name'])
            for entry in db.collection.find() if entry['funcs_matched'] > 0
        ]
        if args.max_repos is not None:
            repos = repos[:args.max_repos]
        process_fn: Callable[[RepoInfo],
                             None] = functools.partial(process,
                                                       data_dir=args.input_dir,
                                                       queue=example_queue)
        pool.map_async(process_fn, repos, error_callback=flutes.log_exception)
        end_signals = 0
        progress = tqdm.tqdm(total=len(repos))
        file_cnt = 0
        text_data = []

        def save_file():
            nonlocal file_cnt, text_data
            # Save text & AST separately
            with (output_dir / f"data_{file_cnt:03d}.pkl").open("wb") as f:
                pickle.dump(text_data, f, protocol=PICKLE_PROTOCOL)
            progress.write(f"Saved part {file_cnt:03d}")
            text_data = []
            file_cnt += 1

        while end_signals < len(repos):
            elem = example_queue.get()
            if elem == END_SIGNATURE:
                progress.update(1)
                end_signals += 1
                continue

            ex = pickle.loads(elem)
            original_code = ex[1]
            if original_code not in original_code_set:
                original_code_set.add(original_code)
                text_data.append(
                    ex)  # (decompiled, orig, var_names, repo, sha)
                n_examples += 1
            else:
                n_duplicate += 1
            if (n_examples + n_duplicate) % 100 == 0:
                progress.set_postfix(
                    {
                        "duplicate": n_duplicate,
                        "examples": n_examples
                    },
                    refresh=False)
                progress.refresh()
            if len(text_data) >= args.block_size:
                save_file()

        if len(text_data) > 0:
            save_file()
Code Example #23
def main():
    args = Args()
    print(args.to_string())
    random.seed(args.random_seed)
    if args.pdb:
        flutes.register_ipython_excepthook()

    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    data = flutes.LazyList(read_data(args.data_dir))

    @flutes.cache(output_dir / "scores.pkl", name="competency scores")
    def compute_scores():
        # Gather word counts and compute "competency" scores for each example, for use in curriculum learning.
        with flutes.safe_pool(args.n_procs,
                              state_class=CountWordsState) as pool:
            for _ in pool.imap_unordered(
                    CountWordsState.
                    count_words,  # use only the source sentence
                    map(lambda ex: ex.decompiled_code,
                        tqdm(data, desc="Counting words")),
                    chunksize=args.block_size):
                pass
            word_counter = Counter()
            for state in pool.get_states():
                word_counter.update(state.counter)
        total_words = sum(word_counter.values())
        word_scores = {
            w: -math.log(c / total_words)
            for w, c in word_counter.items()
        }
        with flutes.safe_pool(args.n_procs,
                              state_class=ComputeScoreState,
                              init_args=(word_scores, )) as pool:
            scores = list(
                pool.imap(ComputeScoreState.compute_score,
                          map(lambda ex: ex.decompiled_code,
                              tqdm(data, desc="Computing scores")),
                          chunksize=args.block_size))
        return scores

    scores = compute_scores()

    # Generate data splits.
    test_size = args.test_split_size or int(
        len(data) * args.test_split_portion)
    # Dev/Test set: Repositories excluded in training set.
    data_by_repo = defaultdict(list)
    for idx, ex in enumerate(data):
        data_by_repo[ex.repo].append(idx)
    repo_names = list(data_by_repo.keys())

    def create_excluded_split(target_size: int, max_repos: int, extra_train_portion: float, min_repo_size: int = 0) \
            -> Tuple[List[str], List[int], List[int]]:
        # ([name], [index])
        filtered_repos = repo_names
        if min_repo_size > 0:
            filtered_repos = [
                repo for repo in filtered_repos
                if len(data_by_repo[repo]) >= min_repo_size
            ]
        while True:
            repo_count = random.randint(1, min(len(filtered_repos), max_repos))
            chosen_repos = random.choices(filtered_repos, k=repo_count)
            sample_size = sum(len(data_by_repo[name]) for name in chosen_repos)
            if 0.8 * target_size <= sample_size <= 1.1 * target_size:
                # Keep sampling until we get something with appropriate size.
                break
        extra_train_indices = []
        split_indices = []
        for name in chosen_repos:
            indices = data_by_repo[name].copy()
            random.shuffle(indices)
            split_size = int(len(indices) * extra_train_portion)
            extra_train_indices += indices[:split_size]
            split_indices += indices[split_size:]
        return chosen_repos, split_indices, extra_train_indices

    dev_repos, dev_split, extra_train_dev_split = create_excluded_split(
        test_size, args.max_test_repos, args.extra_train_portion)
    for repo_name in dev_repos:
        del data_by_repo[repo_name]
    test_repos, test_split, extra_train_test_split = create_excluded_split(
        test_size, args.max_test_repos, args.extra_train_portion)
    excluded_indices = set(dev_split + extra_train_dev_split + test_split +
                           extra_train_test_split)

    # Training set: all the remaining stuff.
    train_split = [
        idx for idx in range(len(data)) if idx not in excluded_indices
    ]
    train_split.sort(key=lambda i: scores[i]
                     )  # sort training indices according to competency score
    extra_train_split = extra_train_dev_split + extra_train_test_split
    splits = {
        "train": train_split,
        "valid": dev_split,
        "test": test_split,
        "train_extra": extra_train_split,
    }
    with (output_dir / "split_indices.pkl").open("wb") as f:
        pickle.dump(splits, f)

    def write_files(folder_path: str, sentence_fn: Callable[[int],
                                                            Tuple[str, str]]):
        for key, indices in splits.items():
            with open(os.path.join(folder_path, f"{key}.txt"), "w") as f:
                for idx in tqdm(indices,
                                desc=f"Writing {key} set",
                                leave=False):
                    src, tgt = sentence_fn(idx)
                    ex = src, tgt, data[idx].var_names, scores[idx], data[
                        idx].repo, data[idx].sha
                    output = OutputData.encode(*ex)
                    # assert tuple(OutputData.decode(output)) == ex
                    f.write(output)
                    f.write("\n")
            print(f"{key.capitalize()} set written")

    write_files(output_dir, lambda idx: data[idx][:2])
    for key, names in [("valid", dev_repos), ("test", test_repos)]:
        with (output_dir / f"{key}_repos.txt").open("w") as f:
            f.write("\n".join(names))

    if not (output_dir / "vocab.model").exists():
        # Write out training text and train SentencePiece model.
        train_text_path = output_dir / "train_text.txt"
        with train_text_path.open("w") as f:
            for idx in tqdm(train_split, desc="Writing training text"):
                src_tokens = data[idx].decompiled_code.split(TOKEN_SEP)
                new_src_tokens = []
                for token in src_tokens:
                    if token in data[idx].var_names:
                        var1, var2 = data[idx].var_names[token]
                        new_src_tokens += [var1, var2]
                    else:
                        new_src_tokens.append(token)
                f.write(" ".join(new_src_tokens) + "\n")
                f.write(data[idx].original_code.replace(TOKEN_SEP, " ") + "\n")
        spm_train_args = {
            "input": train_text_path,
            "model_prefix": output_dir / "vocab",
            "vocab_size": args.vocab_size,
        }
        spm.SentencePieceTrainer.Train(" ".join(
            f"--{name}={str(value)}"
            for name, value in spm_train_args.items()))

    if args.encode_spm:
        # Encode all sentences with the trained SP model.
        with flutes.safe_pool(args.n_procs,
                              state_class=EncodeSPMState,
                              init_args=(output_dir /
                                         "vocab.model", )) as pool:
            processed_code = list(
                pool.imap(EncodeSPMState.encode_spm,
                          map(
                              lambda ex:
                              (ex.decompiled_code, ex.original_code),
                              tqdm(data, desc="Encoding with SPM")),
                          chunksize=args.block_size))

        write_files(output_dir / "tokenized", lambda idx: processed_code[idx])