Example #1
def main():
    parser = argparse.ArgumentParser(description="Generate datapoints from AST")
    parser.add_argument("--ast_fp", "-a", help="Filepath with the ASTs to be parsed")
    parser.add_argument(
        "--n_ctx", "-c", type=int, default=1000, help="Number of contexts for each dp"
    )
    parser.add_argument(
        "--out_fp", "-o", default="/tmp/dps.txt", help="Filepath with the output dps"
    )

    args = parser.parse_args()
    if os.path.exists(args.out_fp):
        os.remove(args.out_fp)
    logging.info("Writing dps to: {}".format(args.out_fp))

    num_dps = 0
    with open(args.ast_fp, "r") as f, open(args.out_fp, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            aug_tokens, aug_leaf_ids = get_dps(dp, args.n_ctx)
            for (tokens, ext), leaf in zip(aug_tokens, aug_leaf_ids):
                if len(tokens) > 1:
                    json.dump([tokens, ext], fp=fout)
                    fout.write("\n")
                    num_dps += 1

    logging.info("Wrote {} datapoints to {}".format(num_dps, args.out_fp))
Example #2
def main():
    parser = argparse.ArgumentParser(
        description="Parse AST IDs for evaluation")
    parser.add_argument("--ast", help="Filepath with new ASTs")
    parser.add_argument("--out", help="Outfile for ids.txt")
    parser.add_argument("--tokenizer", help="Specify Tokenizer")

    args = parser.parse_args()
    if os.path.exists(args.out):
        os.remove(args.out)

    tokenizer = Tokenizer.from_file(args.tokenizer)

    num_dps = 0
    with open(args.ast) as fin, open(args.out, "w") as fout, open(
            "output/debug_ids.txt", "w") as fout2:
        for i, line in enumerate(file_tqdm(fin)):
            dp = json.loads(line.strip())
            asts = split(dp, 1000, tokenizer)
            fout2.write("{}: {}\n".format(i, len(asts)))
            for ast in asts:
                if len(ast) > 1:
                    ids = {"leaf_ids": ast}
                    json.dump(ids, fp=fout)
                    fout.write("\n")
                    num_dps += 1
    logging.info("Wrote {} datapoints to {}".format(num_dps, args.out))
Example #3
def main():
    parser = argparse.ArgumentParser(description="Generate datapoints from AST")
    parser.add_argument("--ast_fp", "-a", help="Filepath with the ASTs to be parsed")
    parser.add_argument(
        "--out_fp", "-o", default="/tmp/dps.txt", help="Filepath for the output dps"
    )
    parser.add_argument(
        "--n_ctx", "-c", type=int, default=1000, help="Number of contexts for each dp"
    )
    args = parser.parse_args()
    if os.path.exists(args.out_fp):
        os.remove(args.out_fp)
    logging.info("Number of context: {}".format(args.n_ctx))

    num_dps = 0
    logging.info("Loading asts from: {}".format(args.ast_fp))
    with open(args.ast_fp, "r") as f, open(args.out_fp, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            asts = separate_dps(dp, args.n_ctx)
            for ast, extended in asts:
                if len(ast) > 1:
                    json.dump([get_dfs(ast), extended], fp=fout)
                    fout.write("\n")
                    num_dps += 1

    logging.info("Wrote {} datapoints to {}".format(num_dps, args.out_fp))
Example #4
def build_subword_prob_cli(args):
    if os.path.exists(args.output):
        logger.warning(f"{args.output} already exists!")

    logger.info("loading...")
    with open(args.word_freq) as fin:
        word_count_iter = (json.loads(line) for line in file_tqdm(fin))
        subword_counter = build_subword_counter(
            word_count_iter,
            min_count=args.subword_min_count,
            min_len=args.subword_min_len,
            max_len=args.subword_max_len,
            word_boundary=args.word_boundary,
            uniq_factor=args.subword_uniq_factor,
        )
    logger.info("processing...")
    if args.subword_prob_take_root:
        logger.warning(
            "`args.subword_prob_take_root = True` ignored at this step.")
    subword_prob = build_subword_prob(
        subword_counter,
        normalize_prob=normalize_prob,
        min_prob=args.subword_prob_min_prob,
        # take_root=args.subword_prob_take_root,
    )
    logger.info("saving...")
    with open(args.output, 'w') as fout:
        for (subword, prob) in tqdm(subword_prob.most_common()):
            print(json.dumps((subword, prob)), file=fout)
Example #5
def external(file_path, n_vocab):
    outfile = "output/vocab.pkl"
    logging.info("Reading from: {}".format(file_path))
    vocab = Counter()
    with open(file_path, "r") as f:
        for line in file_tqdm(f):
            vocab.update(get_value(json.loads(line.strip()), "ast"))
    vocab_to_keep = [i[0] for i in vocab.most_common(n_vocab)]
    top_total = sum(i[1] for i in vocab.most_common(n_vocab))
    total = sum(vocab.values())

    logging.info("Total # of vocab: {}".format(len(vocab)))
    logging.info(
        "Using {} top vocab covers: {:.2f}% of the entire dataset".format(
            n_vocab, 100 * top_total / total))
    logging.info("Top 10 most common vocab:")
    for v, i in vocab.most_common(10):
        print(v, i)

    # add unk and pad tokens
    vocab_to_keep.append(UNK)
    vocab_to_keep.append(PAD)
    logging.info("Added {} and {}".format(UNK, PAD))

    # dump vocab to file
    with open(outfile, "wb") as fout:
        pickle.dump(vocab_to_keep, fout)
    logging.info("Wrote {} vocab to: {}".format(len(vocab_to_keep), outfile))
Example #6
def external(file_path, suffix, n_ctx):
    outfile = "output/{}_ids.txt".format(suffix)

    if os.path.exists(outfile):
        os.remove(outfile)
    logging.info("Type of id to get: {}".format("all"))

    logging.info("Loading dps from: {}".format(file_path))
    with open(file_path, "r") as f, open(outfile, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            asts = separate_dps(dp, n_ctx)
            for ast, _ in asts:
                ids = {}
                if len(ast) > 1:
                    if "all" in {"leaf", "all"}:
                        ids.update(get_leaf_ids(ast))
                    if "all" in {"value", "all"}:
                        ids.update(get_value_ids(ast))
                    if "all" in {"type", "all"}:
                        ids.update(get_type_ids(ast))

                    json.dump(ids, fp=fout) 
                    fout.write("\n")
    logging.info("Wrote to: {}".format(outfile))
Example #7
def external(file_path, suffix):
    outfile = "output/{}_new_trees.json".format(suffix)
    if os.path.exists(outfile):
        os.remove(outfile)
    logging.info("Loading asts from: {}".format(file_path))
    with open(file_path, "r") as f, open(outfile, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            print(json.dumps(convert(dp)), file=fout)
    logging.info("Wrote dps to: {}".format(outfile))
Example #8
def main():
    parser = argparse.ArgumentParser(
        description="Generate datapoints from source code")
    parser.add_argument("--files_fp",
                        "-f",
                        help="Filepath with the filenames to be parsed")
    parser.add_argument("--out_fp",
                        "-o",
                        default="/tmp/dps.txt",
                        help="Filepath with the output dps")
    parser.add_argument("--base_dir",
                        "-b",
                        help="Base dir to append for the fps")
    parser.add_argument("--n_ctx",
                        "-c",
                        type=int,
                        default=1000,
                        help="Number of contexts for each dp")
    parser.add_argument(
        "id_type",
        choices=["leaf", "value", "token", "all"],
        default="",
        help="Which ids to generate. Default = get the tokens",
    )
    args = parser.parse_args()
    if os.path.exists(args.out_fp):
        os.remove(args.out_fp)
    logging.info("Number of context: {}".format(args.n_ctx))

    num_dps = 0
    logging.info("Loading files from: {}".format(args.base_dir))
    with open(args.files_fp, "r",
              errors='ignore') as f, open(args.out_fp, "w") as fout:
        for line in file_tqdm(f):
            fp = os.path.join(args.base_dir, line.strip())
            try:
                aug_tokens, aug_types = my_tokenize(
                    open(fp).read(), args.n_ctx)
                for (tokens, ext), (types_, _) in zip(aug_tokens, aug_types):
                    if len(tokens) > 1:
                        if args.id_type == "leaf":
                            json.dump(get_leaf_ids(types_), fp=fout)
                        elif args.id_type == "value":
                            json.dump(get_value_ids(types_), fp=fout)
                        elif args.id_type == "all":
                            ids = get_leaf_ids(types_)
                            ids.update(get_value_ids(types_))
                            json.dump(ids, fp=fout)
                        else:
                            json.dump([tokens, ext], fp=fout)
                        fout.write("\n")
                        num_dps += 1
            except:
                continue
    logging.info("Wrote {} datapoints to {}".format(num_dps, args.out_fp))
Example #9
def external(fp, suffix, tokenizer):
    tokenizer = Tokenizer.from_file(tokenizer)
    outfile = "output/{}_ids.txt".format(suffix)
    num_dps = 0
    with open(fp) as fin, open(outfile, "w") as fout:
        for i, line in enumerate(file_tqdm(fin)):
            dp = json.loads(line.strip())
            asts = split(dp, 1000, tokenizer)
            for ast in asts:
                if len(ast) > 1:
                    ids = {"leaf_ids": ast}
                    json.dump(ids, fp=fout)
                    fout.write("\n")
                    num_dps += 1
    logging.info("Wrote {} datapoints to {}".format(num_dps, outfile))
Example #10
def main():
    parser = argparse.ArgumentParser(description="Create vocab for py150 dataset")
    parser.add_argument("--n_vocab", "-n", type=int, default=100000)
    parser.add_argument("--input_fp", "-i")
    parser.add_argument("--out_fp", "-o", default="/tmp/vocab.pkl")
    parser.add_argument(
        "--input_type",
        "-t",
        choices=["ast", "leaf", "source_code"],
        help="Where to get the input from (all AST nodes, leaf nodes, or source code",
    )
    args = parser.parse_args()

    logging.info("Reading from: {}".format(args.input_fp))
    logging.info("Input type: {}".format(args.input_type))
    vocab = Counter()
    # with open('../data/test_trees.json', "r") as f:
    #     for line in file_tqdm(f):
    #         vocab.update(get_value(json.loads(line.strip()), args.input_type))
    with open(args.input_fp, "r") as f:
        for line in file_tqdm(f):
            vocab.update(get_value(json.loads(line.strip()), args.input_type))
    vocab_to_keep = [i[0] for i in vocab.most_common(args.n_vocab)]
    top_total = sum(i[1] for i in vocab.most_common(args.n_vocab))
    total = sum(vocab.values())

    logging.info("Total # of vocab: {}".format(len(vocab)))
    logging.info(
        "Using {} top vocab covers: {:.2f}% of the entire dataset".format(
            args.n_vocab, 100 * top_total / total
        )
    )
    logging.info("Top 10 most common vocab:")
    for v, i in vocab.most_common(10):
        print(v, i)

    # add unk and pad tokens
    vocab_to_keep.append(UNK)
    vocab_to_keep.append(PAD)
    logging.info("Added {} and {}".format(UNK, PAD))

    # dump vocab to file
    with open(args.out_fp, "w") as fout:
        json.dump(vocab_to_keep, fout)
    logging.info("Wrote {} vocab to: {}".format(len(vocab_to_keep), args.out_fp))
Example #11
def main():
    parser = argparse.ArgumentParser(
        description="Generate ids (leaf, values, types) from AST")
    parser.add_argument("--ast_fp",
                        "-a",
                        help="Filepath with the new ASTs to be parsed")
    parser.add_argument("--out_fp",
                        "-o",
                        default="/tmp/ids.txt",
                        help="Filepath for the output ids")
    parser.add_argument("--n_ctx",
                        "-c",
                        type=int,
                        default=1000,
                        help="Number of contexts for each dp")
    parser.add_argument(
        "id_type",
        choices=["leaf", "value", "type", "all"],
        default="leaf",
        help="Which ids to generate. Default = leaf",
    )

    args = parser.parse_args()
    if os.path.exists(args.out_fp):
        os.remove(args.out_fp)
    logging.info("Type of id to get: {}".format(args.id_type))

    logging.info("Loading dps from: {}".format(args.ast_fp))
    with open(args.ast_fp, "r") as f, open(args.out_fp, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            asts = separate_dps(dp, args.n_ctx)
            for ast, _ in asts:
                ids = {}
                if len(ast) > 1:
                    if args.id_type in {"leaf", "all"}:
                        ids.update(get_leaf_ids(ast))
                    if args.id_type in {"value", "all"}:
                        ids.update(get_value_ids(ast))
                    if args.id_type in {"type", "all"}:
                        ids.update(get_type_ids(ast))

                    json.dump(ids, fp=fout)
                    fout.write("\n")
    logging.info("Wrote to: {}".format(args.out_fp))
Example #12
def main():
    parser = argparse.ArgumentParser(
        description="Create vocab for code2seq model for py150 dataset")
    parser.add_argument("--n_vocab", "-n", type=int, default=100000)
    parser.add_argument("--input_fp", "-i")
    parser.add_argument("--out_fp", "-o", default="/tmp/vocab.pkl")
    parser.add_argument(
        "--vocab_type",
        "-v",
        choices=["token", "subtoken", "output"],
        help="What type of vocab to get",
    )
    args = parser.parse_args()

    logging.info("Reading from: {}".format(args.input_fp))
    logging.info("Vocab type: {}".format(args.vocab_type))
    vocab = Counter()
    with open(args.input_fp, "r") as f:
        for line in file_tqdm(f):
            vocab.update(get_value(json.loads(line.strip()), args.vocab_type))
    vocab_to_keep = [i[0] for i in vocab.most_common(args.n_vocab)]
    top_total = sum(i[1] for i in vocab.most_common(args.n_vocab))
    total = sum(vocab.values())

    logging.info("Total # of vocab: {}".format(len(vocab)))
    logging.info(
        "Using {} top vocab covers: {:.2f}% of the entire dataset".format(
            args.n_vocab, 100 * top_total / total))
    logging.info("Top 10 most common vocab:")
    for v, i in vocab.most_common(10):
        print(v, i)

    # add unk and pad tokens
    vocab_to_keep.append(UNK)
    vocab_to_keep.append(PAD)
    vocab_to_keep.append(PLACEHOLDER)
    logging.info("Added {} and {} and {}".format(UNK, PAD, PLACEHOLDER))

    # dump vocab to file
    with open(args.out_fp, "wb") as fout:
        pickle.dump(vocab_to_keep, fout)
    logging.info("Wrote {} vocab to: {}".format(len(vocab_to_keep),
                                                args.out_fp))
Example #13
    def __init__(
        self, base_dir, fp, ids_fp, max_vocab=100000, mode="train"
    ):
        super().__init__()
        if mode not in {"train", "test"}:
            raise Exception("Mode must be either train or test")
        self.mode = mode
        self.fp = fp
        self.max_vocab = max_vocab

        # get all the relevant filepaths
        self.filepaths = {
            "vocab": os.path.join(base_dir, "vocab.pkl"),
            "metrics": os.path.join(base_dir, "{}_metrics.txt".format(mode)),
            "conv": os.path.join(base_dir, "{}_converted.txt".format(mode)),
        }
        self._add_extra_filepaths(base_dir)

        logging.info("Writing metrics to: {}".format(self.filepaths["metrics"]))

        # filter dataset
        filtered_fp = self._filter_dataset()

        # set up vocab
        self.vocab = self._create_vocab()

        # convert
        if not os.path.exists(self.filepaths["conv"]):
            with open(filtered_fp, "r") as fin, open(
                self.filepaths["conv"], "w"
            ) as fout:
                for line in utils.file_tqdm(fin):
                    line = json.loads(line.strip())
                    print(json.dumps(self.vocab.convert(line)), file=fout)
            logging.info(
                "Converted dataset to idx and saved to: {}".format(
                    self.filepaths["conv"]
                )
            )

        # return dataset
        self.dataset = self._create_dataset(self.filepaths["conv"], ids_fp)
        logging.info("Loaded dataset from {}".format(self.filepaths["conv"]))
Example #14
def external(file_path, suffix, context_size, overlap):
    outfile = "output/{}_dps.txt".format(suffix)
    if os.path.exists(outfile):
        os.remove(outfile)
    logging.info("Number of context: {}".format(context_size))

    num_dps = 0
    logging.info("Loading asts from: {}".format(file_path))
    with open(file_path, "r") as f, open(outfile, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            asts = rq6_separate_dps(dp, context_size, overlap)
            for ast, extended in asts:
                if len(ast) > 1:
                    json.dump([get_dfs(ast), extended], fp=fout)
                    fout.write("\n")
                    num_dps += 1

    logging.info("Wrote {} datapoints to {}".format(num_dps, outfile))
Example #15
def preprocess(fp, suffix, tokenizer):
    tokenizer = Tokenizer.from_file(tokenizer)
    dps_outfile = "output/{}_dps.txt".format(suffix)
    ids_outfile = "output/{}_ids.txt".format(suffix)
    num = 0
    with open(fp) as fin, open(dps_outfile,
                               "w") as fout_dps, open(ids_outfile,
                                                      "w") as fout_ids:
        for i, line in enumerate(file_tqdm(fin)):
            dp = json.loads(line.strip())
            asts, ids = split(dp, 1000, tokenizer)
            for i, (ast, extended) in enumerate(asts):
                if len(ast) > 1:
                    json.dump([ast, extended], fp=fout_dps)
                    json.dump(ids[i], fp=fout_ids)
                    fout_dps.write("\n")
                    fout_ids.write("\n")
                    num += 1
    logging.info("Wrote {} datapoints to {} and {}".format(
        num, ids_outfile, dps_outfile))
Example #16
def main():
    parser = argparse.ArgumentParser(
        description="Generate datapoints from AST")
    parser.add_argument("--ast_fp",
                        "-a",
                        help="Filepath with the ASTs to be parsed")
    parser.add_argument("--out_fp",
                        "-o",
                        default="/tmp/dps.txt",
                        help="Filepath with the output dps")
    parser.add_argument("--n_ctx",
                        "-c",
                        type=int,
                        default=1000,
                        help="Number of contexts for each dp")
    parser.add_argument(
        "--max_path_len",
        "-p",
        type=int,
        default=13,
        help="Max length of rootpath route",
    )

    args = parser.parse_args()
    if os.path.exists(args.out_fp):
        os.remove(args.out_fp)
    logging.info("Writing dps to: {}".format(args.out_fp))

    num_dps = 0
    with open(args.ast_fp, "r") as f, open(args.out_fp, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            for dp in get_dps(dp, args.n_ctx, args.max_path_len):
                if len(dp[0]) > 1:
                    json.dump(dp, fout)
                    fout.write("\n")
                    num_dps += 1

    logging.info("Wrote {} datapoints to {}".format(num_dps, args.out_fp))
Example #17
def build_subword_vocab_cli(args):
    if os.path.exists(args.output):
        logger.warning(f"{args.output} already exists!")

    logger.info("loading...")
    with open(args.word_freq) as fin:
        word_count_iter = (json.loads(line) for line in file_tqdm(fin))
        subword_counter = build_subword_counter(
            word_count_iter,
            max_size=args.subword_vocab_max_size,
            min_count=args.subword_min_count,
            min_len=args.subword_min_len,
            max_len=args.subword_max_len,
            word_boundary=args.word_boundary,
            uniq_factor=args.subword_uniq_factor,
        )
    logger.info("processing...")
    subword_vocab = subword_counter
    logger.info("saving...")
    with open(args.output, 'w') as fout:
        for (subword, count) in tqdm(subword_vocab.most_common()):
            print(json.dumps((subword, count)), file=fout)
Example #18
def load_embedding(filename: str,
                   show_progress=False) -> (List[str], np.ndarray):
    """
    :param filename: a .txt file or a .pkl/.pickle file
    :return: tuple (words, embeddings)
    """
    import os
    if show_progress:
        from utils import file_tqdm
    else:
        from utils import dummy_tqdm as file_tqdm

    _, ext = os.path.splitext(filename)
    if ext in (".txt", ".w2v"):
        vocab, emb = [], []
        with open(filename, "r") as fin:
            if ext == ".w2v":
                next(fin)
            for line in file_tqdm(fin):
                ss = line.split()
                try:
                    emb.append([float(x) for x in ss[1:]])
                    vocab.append(ss[0])
                except ValueError:
                    print(f"Error loading the line: {line[:30]} ...")
        emb = np.array(emb)
    elif ext in (".pickle", ".pkl"):
        import pickle
        try:
            with open(filename, 'rb') as bfin:
                vocab, emb = pickle.load(bfin)
        except UnicodeDecodeError:
            with open(filename, 'rb') as bfin:
                vocab, emb = pickle.load(bfin, encoding='bytes')
    else:
        raise ValueError(f'Unsupported target vector file extension: {filename}')

    return vocab, emb
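
A hypothetical call to the loader above (the file name is illustrative only):

words, embeddings = load_embedding("vectors.txt", show_progress=True)
print(len(words), embeddings.shape)
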
Example #19
def main():
    parser = argparse.ArgumentParser(
        description="Generate datapoints from AST")
    parser.add_argument("--input_fp",
                        "-i",
                        help="Filepath with the ASTs to be parsed")
    parser.add_argument(
        "--out_fp",
        "-o",
        default="/tmp/new_trees.json",
        help="Filepath with the output dps",
    )

    args = parser.parse_args()
    if os.path.exists(args.out_fp):
        os.remove(args.out_fp)

    logging.info("Loading asts from: {}".format(args.input_fp))
    with open(args.input_fp, "r") as f, open(args.out_fp, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            print(json.dumps(convert(dp)), file=fout)
    logging.info("Wrote dps to: {}".format(args.out_fp))
Example #20
add_subword_vocab_args(parser)
add_logging_args(parser)
args = parser.parse_args()

set_logging_config(args)
dump_args(args)

logger.info(f"building subword prob from `{args.prob_word_freq}`...")
if args.prob_word_freq.lower().startswith("unigram_freq"):
    word_freq_path = import_module("datasets.unigram_freq")\
        .prepare_unigram_freq_paths().word_freq_path
else:
    raise ValueError(
        f"args.prob_word_freq=`{args.prob_word_freq}` not supported.")
with open(word_freq_path) as fin:
    word_count_iter = (json.loads(line) for line in file_tqdm(fin))
    subword_counter = build_subword_counter(
        word_count_iter,
        min_count=args.subword_min_count,
        min_len=args.subword_min_len,
        max_len=args.subword_max_len,
        word_boundary=args.word_boundary,
        uniq_factor=args.subword_uniq_factor,
    )
subword_prob = build_subword_prob(
    subword_counter,
    normalize_prob=normalize_prob,
    min_prob=args.subword_prob_min_prob,
    take_root=args.subword_prob_take_root,
)
logger.info(f"subword prob size: {len(subword_prob)}")
Example #21
    def load_from_file(self, file_path):
        with open(file_path) as fin:
            for line in utils.file_tqdm(fin):
                self.inputs.append([int(s) for s in line.split()])
Example #22
def prepare_glove_paths(
    dir_path=dir_path,
    zip_path=zip_path,
    raw_emb_path=raw_emb_path,
    txt_emb_path=txt_emb_path,
    word_freq_path=word_freq_path,
    w2v_emb_path=w2v_emb_path,
    raw_count_path=raw_count_path,
):
    if not os.path.exists(zip_path):
        logger.info("downloading zip file...")
        url = "http://nlp.stanford.edu/data/glove.840B.300d.zip"
        sp.run(f"wget -O {zip_path} {url}".split())

    if not os.path.exists(raw_emb_path):
        logger.info("unzipping...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(dir_path)

    if not os.path.exists(txt_emb_path):
        logger.info("generating txt emb file...")
        with open(raw_emb_path, "r") as fin, open(txt_emb_path, "w") as fout:
            vocab_len = 0
            for line in file_tqdm(fin):
                ss = line.split()
                if len(ss) != emb_dim + 1:
                    logging.critical(
                        f'line "{line[:30]}"... might include word with space, skipped'
                    )
                    continue

                w = ss[0]

                # copied from `datasets/google/converter.py`
                aw = unicodedata.normalize("NFKD", w).encode("ASCII", "ignore")
                if 20 > len(aw) > 1 and not any(
                        c in w for c in " _./") and aw.islower():
                    vocab_len += 1
                    fout.write(line)

    if not os.path.exists(w2v_emb_path):
        logger.info("generating w2v emb file...")
        with open(txt_emb_path) as fin, open(w2v_emb_path, "w") as fout:
            print(vocab_len, emb_dim, file=fout)
            for line in file_tqdm(fin):
                fout.write(line)

    if not os.path.exists(word_freq_path):
        logger.info("generating word freq jsonl file...")
        with open(txt_emb_path) as fin, open(word_freq_path, "w") as fout:
            for line in fin:
                print(json.dumps((line.split()[0], 1)), file=fout)

    if not os.path.exists(raw_count_path):
        logger.info("generating word freq txt file...")
        with open(txt_emb_path) as fin, open(raw_count_path, "w") as fout:
            for line in fin:
                print(line.split()[0], 1, file=fout, sep='\t')

    return dotdict(
        dir_path=dir_path,
        raw_emb_path=raw_emb_path,
        txt_emb_path=txt_emb_path,
        w2v_emb_path=w2v_emb_path,
        word_freq_path=word_freq_path,
        raw_count_path=raw_count_path,
    )
Example #23
def main(args):
    set_logging_config(args)

    save_path = args.model_path.format(
        timestamp=datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    save_dir, _ = os.path.split(save_path)
    try:
        os.makedirs(save_dir)
    except FileExistsError:
        logger.warning(
            "Things will get overwritten for directory {}".format(save_dir))

    dump_args(args, logger, os.path.join(save_dir, 'args.json'))

    logger.info(f'loading target vectors from `{args.target_vectors}`...')
    target_words, target_embs = \
        load_embedding(args.target_vectors, show_progress=True)
    logger.info(f'embeddings loaded with {len(target_words)} words')

    logger.info(f"loading subword vocab from `{args.subword_vocab}`...")
    with open(args.subword_vocab) as fin:
        subword_vocab = dict(json.loads(line) for line in file_tqdm(fin))
    logger.info(f"subword vocab size: {len(subword_vocab)}")

    if args.subword_prob:
        logger.info(f"loading subword prob from `{args.subword_prob}`...")
        with open(args.subword_prob) as fin:
            subword_prob = dict(json.loads(line) for line in file_tqdm(fin))
        subword_prob = subword_prob_post_process(
            subword_prob,
            min_prob=args.subword_prob_min_prob,
            # take_root=args.subword_prob_take_root,
        )
    else:
        subword_prob = None

    np.random.seed(args.random_seed)

    def MSE(pred, target):
        return sum((pred - target)**2) / 2

    def MSE_backward(pred, target):
        return (pred - target)

    model = PBoS(
        embedding_dim=len(target_embs[0]),
        subword_vocab=subword_vocab,
        subword_prob=subword_prob,
        weight_threshold=args.subword_weight_threshold,
        eps=args.subword_prob_eps,
        take_root=args.subword_prob_take_root,
        normalize_semb=args.normalize_semb,
    )
    start_time = time()
    for i_epoch in range(args.epochs):
        h = []
        h_epoch = []
        lr = args.lr / (1 + i_epoch)**0.5 if args.lr_decay else args.lr
        logger.info('epoch {:>2} / {} | lr {:.5f}'.format(
            1 + i_epoch, args.epochs, lr))
        epoch_start_time = time()
        for i_inst, wi in enumerate(
                np.random.choice(len(target_words),
                                 len(target_words),
                                 replace=False),
                start=1,
        ):
            target_emb = target_embs[wi]
            word = target_words[wi]
            model_word = bound_word(word) if args.word_boundary else word
            model_emb = model.embed(model_word)
            grad = MSE_backward(model_emb, target_emb)

            if i_inst % 20 == 0:
                loss = MSE(model_emb, target_emb) / len(
                    target_emb)  # average over dimension for easy reading
                h.append(loss)
            if i_inst % 10000 == 0:
                width = len(f"{len(target_words)}")
                fmt = 'processed {:%d}/{:%d} | loss {:.5f}' % (width, width)
                logger.info(
                    fmt.format(i_inst, len(target_words), np.average(h)))
                h_epoch.extend(h)
                h = []

            d = -lr * grad
            model.step(model_word, d)
        now_time = time()
        logger.info(
            'epoch {i_epoch:>2} / {n_epoch} | loss {loss:.5f} | time {epoch_time:.2f}s / {training_time:.2f}s'
            .format(
                i_epoch=1 + i_epoch,
                n_epoch=args.epochs,
                loss=np.average(h_epoch),
                epoch_time=now_time - epoch_start_time,
                training_time=now_time - start_time,
            ))

    logger.info('saving model...')
    model.dump(save_path)
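
The hand-written gradient above uses the fact that the derivative of MSE(pred, target) = sum((pred - target)**2) / 2 with respect to pred is simply pred - target. A small, self-contained sanity check of that identity against a finite-difference estimate (illustrative values only, not part of the original script):

import numpy as np

pred = np.array([0.5, -1.0, 2.0])
target = np.array([0.0, 0.0, 1.0])
eps = 1e-6

def mse(p, t):
    return np.sum((p - t) ** 2) / 2

analytic = pred - target
numeric = np.array([
    (mse(pred + eps * e, target) - mse(pred - eps * e, target)) / (2 * eps)
    for e in np.eye(len(pred))
])
print(np.allclose(analytic, numeric))  # True
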
Example #24
import json
import utils

# This escapes special characters and replaces spaces in node types/values; the output is the base for the tokenizer
# TODO: Add ext info and store as json

with open("output/dps.txt", "r") as fin, open("output/new_ast_raw.txt",
                                              "w") as fout:
    for line in utils.file_tqdm(fin):
        json_line = json.loads(line)
        nodes = []
        for node in json_line[0]:
            nodes.append(node.strip().replace(
                " ", "<spc>").encode("unicode_escape").decode())
        fout.write(" ".join(nodes) + "\n")