Example #1
def build_language_ast(name: str, dirs: List[Path], pickle_path: Path,
                       data_params: DatasetParams):
    start = time.time()

    if data_params.use_ast == "tree-sitter":
        parser = TreeSitterParser(
            langs=["go", "java", "javascript", "python", "php", "ruby"],
            added_nodes=data_params.ast_added_nodes,
            skip_node_types=data_params.ast_skip_node_types,
        )

        all_special_tokens: Set[str] = set()

        lengths: Dict[str, List[int]] = {
            "go": [],
            "java": [],
            "javascript": [],
            "python": [],
            "php": [],
            "ruby": []
        }

        for (idx, file_path) in enumerate(get_data_files_from_directory(dirs)):
            logger.info(f"Reading {file_path}")
            raw_samples = list(read_file_samples(file_path))
            for raw_sample in raw_samples:
                lang = raw_sample["language"]
                tokens, special_tokens = parser.parse_full(
                    lang, raw_sample["code"])

                all_special_tokens.update(special_tokens)

                lengths[lang].append(len(tokens))

        end = time.time()
        logger.debug(
            f"all_special_tokens ({len(all_special_tokens)}) {all_special_tokens}"
        )

        os.makedirs(pickle_path, exist_ok=True)

        json_file = pickle_path / f"{name}_special_tokens.json"
        with open(json_file, "w") as f:
            json.dump(list(all_special_tokens), f)

        import statistics

        for lang, lgs in lengths.items():
            if len(lgs) > 0:
                max_lg = max(lgs)
                min_lg = min(lgs)
                mean_lg = statistics.mean(lgs)
                # stdev needs at least two data points
                std_lg = statistics.stdev(lgs) if len(lgs) > 1 else 0.0
                logger.debug(
                    f"{lang} [ min:{min_lg}, max:{max_lg}, mean:{mean_lg}, stddev:{std_lg} ]"
                )

        time_p = end - start
        logger.info(f"Building AST took: {time_p} sec")
Example #2
def build_huggingface_token_files(
    data_dirs: List[Path],
    data_params: DatasetParams,
    output_path: Union[Path, str],
    sample_update: Callable[[str, str, List[str]],
                            str] = default_sample_update,
) -> Tuple[List[Path], Dict[str, Path]]:
    tokenizers_path = Path(output_path)
    os.makedirs(tokenizers_path, exist_ok=True)
    # build files of strings
    lang_ios: Dict[str, Tuple[IO[str], IO[str]]] = {}

    query_files: List[Path] = []
    lang_files: Dict[str, Path] = {}
    for idx, file_path in enumerate(get_data_files_from_directory(data_dirs)):
        logger.info(f"Reading {file_path}")
        for raw_sample in read_file_samples(file_path):
            lang = raw_sample["language"]
            if lang not in lang_ios:
                query_file = tokenizers_path / f"{lang}_query.txt"
                code_file = tokenizers_path / f"{lang}_code.txt"
                lang_ios[lang] = (open(query_file, "w"), open(code_file, "w"))
                query_files.append(query_file)
                lang_files[lang] = code_file
            lang_ios[lang][0].write(
                sample_update("query", lang, raw_sample["docstring_tokens"]))
            lang_ios[lang][1].write(
                sample_update("code", lang, raw_sample["code_tokens"]))

    # flush and close the per-language output files before returning
    for query_io, code_io in lang_ios.values():
        query_io.close()
        code_io.close()

    return query_files, lang_files
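A hedged usage sketch: the function writes one query file and one code file per language under `output_path` and returns the list of query files plus a language-to-code-file mapping (paths below are hypothetical):

from pathlib import Path

# data_params: DatasetParams loaded elsewhere (assumed)
query_files, lang_files = build_huggingface_token_files(
    data_dirs=[Path("data/codesearchnet/train")],   # hypothetical corpus location
    data_params=data_params,
    output_path="output/tokenizers",                # <lang>_query.txt / <lang>_code.txt per language
)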
Example #3
def parse_data_file_siamese_tokenizer(
    data_file: Path, data_params: DatasetParams, tokenizer: TokenizerRecordable, lang_token: str, query_token: str
) -> Tuple[str, int, Samples]:
    logger.info(f"Reading samples from {data_file}")
    filename = os.path.basename(data_file)
    file_language = filename.split("_")[0]

    samples = list(read_file_samples(data_file))

    ds: List[Dict[str, Union[str, int]]] = []
    for raw_sample in samples:
        language = raw_sample["language"]
        if language.startswith("python"):  # In some datasets, we use 'python-2.7' and 'python-3'
            language = "python"

        if language != file_language:
            msg = f"sample language {language} does not match file language {file_language} from filename"
            logger.error(msg)
            sys.exit(msg)

        # load_data_from_sample_siamese returns the processed data for the sample,
        # or None if the sample should be skipped
        function_name = raw_sample.get("func_name")
        data_code = load_data_from_sample_siamese(
            language=language,
            encoder_label="code",
            data_to_load=raw_sample["code_tokens"],
            function_name=function_name,
            tokenizer=tokenizer,
            fraction_using_func_name=data_params.fraction_using_func_name,
            min_len_func_name_for_query=data_params.min_len_func_name_for_query,
            use_subtokens=data_params.use_subtokens,
            mark_subtoken_end=data_params.mark_subtoken_end,
            max_num_tokens=data_params.code_max_num_tokens,
            lang_token=lang_token,
            query_token=query_token,
        )

        # query doesn't use the language
        data_query = load_data_from_sample_siamese(
            language=language,
            encoder_label="query",
            data_to_load=[d.lower() for d in raw_sample["docstring_tokens"]],
            function_name=function_name,
            tokenizer=tokenizer,
            fraction_using_func_name=data_params.fraction_using_func_name,
            min_len_func_name_for_query=data_params.min_len_func_name_for_query,
            use_subtokens=data_params.use_subtokens,
            mark_subtoken_end=data_params.mark_subtoken_end,
            max_num_tokens=data_params.query_max_num_tokens,
            lang_token=lang_token,
            query_token=query_token,
        )

        if data_code is not None and data_query is not None:
            d = {"language": language, "similarity": 1, **data_code, **data_query}
            ds.append(d)

    logger.debug(f"Parsed file {data_file}: language {file_language} [{len(ds)} samples]")

    return (file_language, len(ds), ds)
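A call sketch, assuming a trained `TokenizerRecordable` is already loaded; the shard name and the special token strings are hypothetical placeholders (the language prefix in the file name is required, since the function derives `file_language` from it):

from pathlib import Path

file_language, count, samples = parse_data_file_siamese_tokenizer(
    data_file=Path("data/python_train_0.jsonl.gz"),  # hypothetical shard name
    data_params=data_params,                          # assumed to be loaded elsewhere
    tokenizer=tokenizer,                              # any TokenizerRecordable implementation (assumed)
    lang_token="<lg>",                                # hypothetical special tokens
    query_token="<qy>",
)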
Example #4
def parse_data_file_ast_tokenizer(
    data_file: Path,
    data_params: DatasetParams,
    tokenizer: TokenizerRecordable,
    ast_parser: TreeSitterParser,
    query_token: str,
    pickle_path: Path,
) -> Tuple[str, pd.DataFrame]:
    logger.info(f"Reading samples from {data_file}")
    filename = os.path.basename(data_file)
    file_language = filename.split("_")[0]
    file_id = filename.split(".")[0]
    pickle_file = pickle_path / f"{file_id}.p"

    if pickle_file.exists():
        df = pd.read_pickle(pickle_file)
        return (file_language, df)

    samples = list(read_file_samples(data_file))

    codes: List[List[str]] = []
    funcs: List[List[str]] = []
    docstrings: List[List[str]] = []
    for idx, raw_sample in enumerate(tqdm(samples)):
        language = raw_sample["language"]
        if language.startswith("python"):  # In some datasets, we use 'python-2.7' and 'python-3'
            language = "python"

        if language != file_language:
            msg = f"sample language {language} does not match file language {file_language} from filename"
            logger.error(msg)
            sys.exit(msg)

        function_name = raw_sample.get("func_name")

        code: List[str] = ast_parser.parse(
            language,
            raw_sample["code"],
            max_tokens=data_params.code_max_num_tokens)

        # Skip samples where the function name is very short, because it probably has too little information
        # to be a good search query.
        if (data_params.fraction_using_func_name > 0.0 and function_name and
                len(function_name) >= data_params.min_len_func_name_for_query):
            func = [query_token] + split_identifier_into_parts(function_name)
            # mask the function name in the code tokens (the func query is built from it)
            code = [
                tokenizer.unk_token() if token == function_name else token
                for token in code
            ]
            docstring = [query_token] + [
                d.lower() for d in raw_sample["docstring_tokens"]
            ]

            codes.append(code)
            funcs.append(func)
            docstrings.append(docstring)

    code_toks: List[List[int]] = []
    code_masks: List[List[int]] = []
    func_toks: List[List[int]] = []
    func_masks: List[List[int]] = []
    docstring_toks: List[List[int]] = []
    docstring_masks: List[List[int]] = []

    for batch in batch_iter(codes, batch_size=100):
        toks, masks = tokenizer.encode_tokens(
            batch, max_length=data_params.code_max_num_tokens)
        code_toks.extend(toks)
        code_masks.extend(masks)

    for batch in batch_iter(funcs, batch_size=100):
        toks, masks = tokenizer.encode_tokens(
            batch, max_length=data_params.query_max_num_tokens)
        func_toks.extend(toks)
        func_masks.extend(masks)

    for batch in batch_iter(docstrings, batch_size=100):
        toks, masks = tokenizer.encode_tokens(
            batch, max_length=data_params.query_max_num_tokens)
        docstring_toks.extend(toks)
        docstring_masks.extend(masks)

    langs = [data_params.lang_ids[file_language]] * len(func_toks)
    similarities = [1] * len(func_toks)
    logger.debug(f"func_toks {func_toks[:2]}")
    logger.debug(f"docstring_toks {docstring_toks[:2]}")
    logger.debug(f"code_toks {code_toks[:2]}")
    logger.debug(f"langs {langs[:2]}")
    logger.debug(f"similarities {similarities[:2]}")
    df = pd.DataFrame({
        "lang": langs,
        "similarity": similarities,
        "func_tokens": func_toks,
        "func_masks": func_masks,
        "docstring_tokens": docstring_toks,
        "docstring_masks": docstring_masks,
        "code_tokens": code_toks,
        "code_masks": code_masks,
    })

    df.to_pickle(pickle_file)

    logger.debug(
        f"Saved file {data_file}: language {file_language} [{df.shape}] to {pickle_file}"
    )

    return (file_language, df)
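A call sketch tying this example back to example #1's parser, assuming the same `DatasetParams` object and an already-loaded tokenizer; file names, paths, and the query token are hypothetical:

from pathlib import Path

ast_parser = TreeSitterParser(
    langs=["python"],
    added_nodes=data_params.ast_added_nodes,
    skip_node_types=data_params.ast_skip_node_types,
)
file_language, df = parse_data_file_ast_tokenizer(
    data_file=Path("data/python_train_0.jsonl.gz"),  # hypothetical shard name
    data_params=data_params,
    tokenizer=tokenizer,                              # assumed TokenizerRecordable
    ast_parser=ast_parser,
    query_token="<qy>",                               # hypothetical special token
    pickle_path=Path("output/pickles"),               # cached DataFrame is read from / written here
)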