예제 #1
0
def _process_split(filename, data_type, bpe, record_file, eval_file,
                   meta_file):
    """Featurize one dataset split with ``bpe`` and save its outputs.

    Runs ``process_file`` on ``filename``, passes the resulting examples to
    ``build_features``, then saves the eval object and the meta object.

    Args:
        filename: Path of the raw split file, passed to ``process_file``.
        data_type: Split label ("train", "dev" or "test"). It is passed to
            ``process_file``/``build_features`` and used to build the save
            messages, e.g. "train eval" / "train meta".
        bpe: Encoder passed through to ``process_file``.
        record_file: Output path passed to ``build_features``.
        eval_file: Path the eval object from ``process_file`` is saved to.
        meta_file: Path the meta object from ``build_features`` is saved to.
    """
    examples, eval_obj = process_file(filename, data_type, bpe)
    meta = build_features(examples, data_type, record_file)
    save(eval_file, eval_obj, message=f"{data_type} eval")
    save(meta_file, meta, message=f"{data_type} meta")


def preprocess(args):
    """Obtain a BPE encoder and featurize the train, dev and (optionally) test splits.

    Args:
        args: Configuration namespace. Reads ``bpe_file``, ``max_tokens``,
            ``include_test_examples`` and, for each of train/dev/test, the
            ``*_file``, ``*_record_file``, ``*_eval_file`` and
            ``*_meta_file`` attributes (test ones only when
            ``include_test_examples`` is truthy).
    """
    # Reuse a previously saved BPE state when ``try_load`` returns one;
    # otherwise learn a new encoding from the training file and save it so
    # later runs can skip this step.
    bpe_state_dict = try_load(args.bpe_file)
    if bpe_state_dict is None:
        bpe = process_bpe_file(args.train_file, "train", args.max_tokens)
        save(args.bpe_file, bpe.state_dict(), message="BPE")
    else:
        bpe = BPE()
        bpe.load_state_dict(bpe_state_dict)

    _process_split(args.train_file, "train", bpe, args.train_record_file,
                   args.train_eval_file, args.train_meta_file)

    _process_split(args.dev_file, "dev", bpe, args.dev_record_file,
                   args.dev_eval_file, args.dev_meta_file)

    if args.include_test_examples:
        _process_split(args.test_file, "test", bpe, args.test_record_file,
                       args.test_eval_file, args.test_meta_file)
예제 #2
0
def process_bpe_file(filename, data_type, max_length):
    """Learn a BPE encoding from the contexts and questions in a JSON file.

    Args:
        filename: Path to a JSON file whose top-level "data" list holds
            articles. Each article has "paragraphs", and each paragraph has a
            "context" string and a "qas" list of entries with a "question"
            string.
        data_type: Label used only in the progress messages.
        max_length: Passed unchanged to ``BPE.learn_bpe``.

    Returns:
        A ``BPE`` instance after ``build_base_vocab``, ``build_vocab`` and
        ``learn_bpe`` have been called on it.
    """
    print(f"Getting vocab from {data_type} examples...")
    # JSON text is UTF-8 (RFC 8259); decode explicitly instead of relying on
    # the platform's locale encoding, which is not UTF-8 everywhere.
    with open(filename, "r", encoding="utf-8") as file:
        source = json.load(file)

    # Maps each text (context or question) to its occurrence weight.
    lines = Counter()
    for article in tqdm(source["data"]):
        for paragraph in article["paragraphs"]:
            qas = paragraph["qas"]
            # Weight each context by the number of questions it has.
            lines[paragraph["context"]] += len(qas)
            for qa in qas:
                lines[qa["question"]] += 1

    bpe = BPE()
    bpe.build_base_vocab()
    bpe.build_vocab(lines)
    print(f"Learning bpe on {len(bpe.vocab)} words for {data_type}")
    bpe.learn_bpe(max_length)
    return bpe
예제 #3
0
def get_bpe(args):
    """Load a BPE encoder from the JSON state dict stored at ``args.bpe_file``.

    Also calls ``add_special_tokens(args)`` before returning.

    Args:
        args: Configuration namespace; ``args.bpe_file`` is read here and the
            whole namespace is passed to ``add_special_tokens``.

    Returns:
        A ``BPE`` instance populated via ``load_state_dict``.
    """
    # JSON text is UTF-8 (RFC 8259); decode explicitly instead of relying on
    # the platform's locale encoding.
    with open(args.bpe_file, "r", encoding="utf-8") as file:
        state_dict = json.load(file)

    bpe = BPE()
    bpe.load_state_dict(state_dict)
    # NOTE(review): this receives ``args`` rather than ``bpe`` and its return
    # value is discarded -- confirm it is not meant to modify ``bpe``.
    add_special_tokens(args)
    return bpe