Example #1
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--file_path", "-f", required=True)

    args, unknown = parser.parse_known_args()

    data = read_data(args.file_path)

    directory, file_name = os.path.split(args.file_path)

    data_after_processing = []

    for line in data:

        line = "".join(c for c in line if c != "@")

        data_after_processing.append(line)

    idx = file_name.rfind(".")

    if idx == -1:
        new_file_name = file_name + "_remove_at"
    else:
        new_file_name = file_name[:idx] + "_remove_at" + file_name[idx:]

    write_data(data_after_processing, os.path.join(directory, new_file_name))
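Every example in this listing relies on read_data and write_data helpers that are not shown. A minimal sketch of what they might look like, assuming plain UTF-8 text files with one sentence per line (the signatures and the list-joining behaviour are assumptions, not part of the original code):

from typing import List, Union


def read_data(file_path: str) -> List[str]:
    # Assumed helper: read a text file into a list of lines, newline stripped.
    with open(file_path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def write_data(data: List[Union[str, List[str]]], file_path: str) -> None:
    # Assumed helper: write one item per line; token lists are joined with spaces.
    with open(file_path, "w", encoding="utf-8") as f:
        for item in data:
            if isinstance(item, list):
                item = " ".join(item)
            f.write(item + "\n")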
Example #2
def remove_long_sentence(args):

    assert args.max_sentence_length is not None

    for src_file_path, tgt_file_path, filtered_src_file_path, filtered_tgt_file_path in zip(
            args.src_file_path_list, args.tgt_file_path_list,
            args.output_src_file_path_list, args.output_tgt_file_path_list):

        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        assert len(src_data) == len(tgt_data)

        src_sentence_filtered = []
        tgt_sentence_filtered = []

        for src_sentence, tgt_sentence in zip(src_data, tgt_data):

            if len(src_sentence.split()) <= args.max_sentence_length and \
                    len(tgt_sentence.split()) <= args.max_sentence_length:
                src_sentence_filtered.append(src_sentence)
                tgt_sentence_filtered.append(tgt_sentence)

        write_data(src_sentence_filtered, filtered_src_file_path)
        write_data(tgt_sentence_filtered, filtered_tgt_file_path)
Example #3
def remove_long_sentence(src_file_path_list: list, tgt_file_path_list: list,
                         max_sentence_length: int):

    assert len(src_file_path_list) == len(tgt_file_path_list)

    for src_file_path, tgt_file_path in zip(src_file_path_list,
                                            tgt_file_path_list):

        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        src_sentence_filtered = []
        tgt_sentence_filtered = []

        for src_sentence, tgt_sentence in zip(src_data, tgt_data):

            if len(src_sentence.split()) <= max_sentence_length and len(
                    tgt_sentence.split()) <= max_sentence_length:
                src_sentence_filtered.append(src_sentence)
                tgt_sentence_filtered.append(tgt_sentence)

        new_src_file_path = get_file_path(src_file_path)
        new_tgt_file_path = get_file_path(tgt_file_path)

        write_data(src_sentence_filtered, new_src_file_path)
        write_data(tgt_sentence_filtered, new_tgt_file_path)
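get_file_path is not defined in this example. Presumably it derives the output path from the input path; the sketch below is a hypothetical implementation that mirrors the suffix-before-extension pattern of Example #1 (the "_filtered" suffix is a guess):

import os


def get_file_path(file_path: str, suffix: str = "_filtered") -> str:
    # Hypothetical helper: insert a suffix before the extension,
    # e.g. corpus/train.en -> corpus/train_filtered.en
    directory, file_name = os.path.split(file_path)
    idx = file_name.rfind(".")
    if idx == -1:
        return os.path.join(directory, file_name + suffix)
    return os.path.join(directory, file_name[:idx] + suffix + file_name[idx:])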
Example #4
def main():

    parser = argparse.ArgumentParser()

    parser.add_argument("--data_path", nargs="+")
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument("--strip_accents", action="store_true")
    parser.add_argument("--tokenize_chinese_chars", action="store_true")

    args, unknown = parser.parse_known_args()

    print("do_lower_case: {}, strip_accents: {}, tokenize_chinese_chars: {}".
          format(args.do_lower_case, args.strip_accents,
                 args.tokenize_chinese_chars))

    tokenizer = BasicTokenizer(
        do_lower_case=args.do_lower_case,
        strip_accents=args.strip_accents,
        tokenize_chinese_chars=args.tokenize_chinese_chars)

    for data_path in args.data_path:

        data = read_data(data_path)
        tok_data = []

        if data_path.endswith(".en"):

            tok_data = [tokenizer.tokenize(sentence) for sentence in data]

        else:
            for sentence in data:

                sentence = sentence.strip()

                first_blank_pos = sentence.find(" ")

                if first_blank_pos != -1:

                    # keep the language identifier token untouched
                    lang_identify_token = sentence[:first_blank_pos]

                    # BasicTokenizer.tokenize returns a list of tokens, so
                    # join it back into a string before re-attaching the tag
                    sentence_after_process = " ".join(
                        tokenizer.tokenize(sentence[first_blank_pos + 1:]))
                    sentence_after_process = " ".join(
                        [lang_identify_token, sentence_after_process])

                    tok_data.append(sentence_after_process)

                else:
                    tok_data.append(sentence)

        idx = data_path.rfind(".")

        assert idx != -1

        tok_data_path = data_path[:idx] + "_tok" + data_path[idx:]

        print("{} > {}".format(data_path, tok_data_path))

        write_data(tok_data, tok_data_path)
Example #5
def remove_blank(args):
    for file in args.zh_corpus_list:

        data = read_data(file)

        data = ["".join(sentence.strip().split()) for sentence in data]

        directory, file_name = os.path.split(file)
        idx = file_name.rfind(".")

        assert idx != -1

        new_file_name = file_name[:idx] + "_no_blank" + file_name[idx:]
        new_file = os.path.join(directory, new_file_name)
        write_data(data, new_file)
Example #6
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--multi_gpu_translation_dir", required=True)
    parser.add_argument("--is_tok", action="store_true")
    parser.add_argument("--merged_translation_dir", required=True)

    args, unknown = parser.parse_known_args()

    translations_dict_per_model = {}

    for file in os.listdir(args.multi_gpu_translation_dir):

        file_name_prefix, extension = os.path.splitext(file)

        if args.is_tok and not file_name_prefix.endswith("_tok"):
            continue
        elif not args.is_tok and file_name_prefix.endswith("_tok"):
            continue

        assert extension[1:5] == "rank"
        rank = int(extension[5:])

        data = read_data(os.path.join(args.multi_gpu_translation_dir, file))

        if file_name_prefix in translations_dict_per_model:
            translations_dict_per_model[file_name_prefix].append((data, rank))
        else:
            translations_dict_per_model[file_name_prefix] = [(data, rank)]

    for file_name_prefix in translations_dict_per_model:
        translations_dict_per_model[file_name_prefix].sort(
            key=lambda item: item[1])

    if not os.path.isdir(args.merged_translation_dir):
        os.makedirs(args.merged_translation_dir)

    for file_name_prefix in translations_dict_per_model:

        merged_translations = []
        for translations, rank in translations_dict_per_model[
                file_name_prefix]:
            merged_translations.extend(translations)

        write_data(
            merged_translations,
            os.path.join(args.merged_translation_dir,
                         "{}.txt".format(file_name_prefix)))
Example #7
def remove_same_sentence(args):

    if args.src_memory_path:
        src_memory_list = []

        for memory_path in args.src_memory_path:
            src_memory_list.extend(read_data(memory_path))

        src_memory = set(src_memory_list)

    else:
        src_memory = set()

    for src_file_path, tgt_file_path, filtered_src_file_path, filtered_tgt_file_path in zip(
            args.src_file_path_list, args.tgt_file_path_list,
            args.output_src_file_path_list, args.output_tgt_file_path_list):

        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        assert len(src_data) == len(tgt_data)

        src_sentence_filtered = []
        tgt_sentence_filtered = []

        sentence_visited = set()
        removed_sentence_id = set()

        for i, sentence in enumerate(src_data):

            if sentence in sentence_visited:
                removed_sentence_id.add(i)
            elif sentence in src_memory:
                removed_sentence_id.add(i)
                sentence_visited.add(sentence)
            else:
                sentence_visited.add(sentence)

        for i, (src_sentence,
                tgt_sentence) in enumerate(zip(src_data, tgt_data)):

            if i not in removed_sentence_id:
                src_sentence_filtered.append(src_sentence)
                tgt_sentence_filtered.append(tgt_sentence)

        write_data(src_sentence_filtered, filtered_src_file_path)
        write_data(tgt_sentence_filtered, filtered_tgt_file_path)
Example #8
def sort_sentence(args):
    for src_file_path, tgt_file_path, sorted_src_file_path, sorted_tgt_file_path in zip(
            args.src_file_path_list, args.tgt_file_path_list,
            args.output_src_file_path_list, args.output_tgt_file_path_list):
        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        assert len(src_data) == len(tgt_data)

        src_data = [sentence.split() for sentence in src_data]
        tgt_data = [sentence.split() for sentence in tgt_data]

        src_data, tgt_data = sort_src_sentence_by_length(
            list(zip(src_data, tgt_data)))

        write_data(src_data, sorted_src_file_path)
        write_data(tgt_data, sorted_tgt_file_path)
Example #9
def tokenize(args):

    assert os.path.isdir(args.raw_corpus_dir)

    if not os.path.isdir(args.tokenized_corpus_dir):
        os.makedirs(args.tokenized_corpus_dir)

    tokenizer = sacrebleu.TOKENIZERS[sacrebleu.DEFAULT_TOKENIZER]

    for file in os.listdir(args.raw_corpus_dir):

        file_path = os.path.join(args.raw_corpus_dir, file)

        idx = file.rfind(".")
        assert idx != -1

        new_file_name = "{}.{}.{}".format(file[:idx], "tok", file[idx+1:])

        data = read_data(file_path)

        data_tok = [tokenizer(sentence) for sentence in data]

        new_file_path = os.path.join(args.tokenized_corpus_dir, new_file_name)
        write_data(data_tok, new_file_path)
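sacrebleu.TOKENIZERS and sacrebleu.DEFAULT_TOKENIZER come from the sacrebleu 1.x API, where the lookup returns a plain tokenizer function. Newer releases reorganized the tokenizers into callable classes; a hedged sketch of the equivalent call on a 2.x install (verify against your installed version):

# Assumption: sacrebleu >= 2.0, where tokenizers are callable class instances.
from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a

tokenizer = Tokenizer13a()  # "13a" was the default tokenizer in sacrebleu 1.x
print(tokenizer("Hello, world!"))  # instances are called directly on a sentence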
Example #10
def merge(corpus_path: str):

    train_data_src = []
    train_data_tgt = []

    dev_data_src = []
    dev_data_tgt = []

    test_data_src = []
    test_data_tgt = []

    unused_corpus = {"pt-br_en", "fr-ca_en", "eo_en", "calv_en"}

    for corpus_dir in os.listdir(corpus_path):

        if corpus_dir in unused_corpus:
            continue

        corpus_dir = os.path.join(corpus_path, corpus_dir)

        for corpus_file_name in os.listdir(corpus_dir):

            corpus_file_path = os.path.join(corpus_dir, corpus_file_name)
            print(corpus_file_path)

            data = read_data(corpus_file_path)

            is_en = True

            if not corpus_file_name.endswith(".en"):

                idx = corpus_file_name.rfind(".")
                assert idx != -1

                idx += 1
                lang_identify_token = "".join(["<", corpus_file_name[idx:], ">"])
                data = [" ".join([lang_identify_token, sentence]) for sentence in data]
                is_en = False

            if corpus_file_name.startswith("train"):

                if is_en:
                    train_data_tgt.extend(data)
                else:
                    train_data_src.extend(data)

            elif corpus_file_name.startswith("dev"):

                if is_en:
                    dev_data_tgt.extend(data)
                else:
                    dev_data_src.extend(data)

            elif corpus_file_name.startswith("test"):

                if is_en:
                    test_data_tgt.extend(data)
                else:
                    test_data_src.extend(data)

    assert check(train_data_src, train_data_tgt) and check(dev_data_src, dev_data_tgt) and \
           check(test_data_src, test_data_tgt)

    output_dir = "/data/rrjin/NMT/data/ted_data/corpus"

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    write_data(train_data_src, os.path.join(output_dir, "train_data_src.combine"))
    write_data(train_data_tgt, os.path.join(output_dir, "train_data_tgt.en"))

    write_data(dev_data_src, os.path.join(output_dir, "dev_data_src.combine"))
    write_data(dev_data_tgt, os.path.join(output_dir, "dev_data_tgt.en"))

    write_data(test_data_src, os.path.join(output_dir, "test_data_src.combine"))
    write_data(test_data_tgt, os.path.join(output_dir, "test_data_tgt.en"))
Example #11
        print("Time spend: {} seconds".format(end_time - start_time))

    if not os.path.exists(args.translation_output_dir):
        os.makedirs(args.translation_output_dir)

    _, model_name = os.path.split(model_path)

    if args.beam_size:
        translation_file_name_prefix = "{}_beam_size{}".format(
            model_name, args.beam_size)
    else:
        translation_file_name_prefix = model_name
    p = os.path.join(args.translation_output_dir,
                     translation_file_name_prefix + "_translations.txt")

    write_data(pred_data, p)

    if args.need_tok:

        # replace '@@ ' with ''

        p_tok = os.path.join(
            args.translation_output_dir,
            translation_file_name_prefix + "_translations_tok.txt")

        tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(p, p_tok)

        call(tok_command, shell=True)

        bleu_calculation_command = "perl {} {} < {}".format(
            args.bleu_script_path, args.test_tgt_path, p_tok)
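The sed command above undoes BPE segmentation by stripping the "@@ " continuation markers before BLEU is computed. If shelling out to sed is not desirable, the same substitution can be done in Python; a sketch, not part of the original script:

import re

_BPE_MARKER = re.compile(r"(@@ )|(@@ ?$)")


def remove_bpe(line: str) -> str:
    # Same substitution as the sed command: drop "@@ " joiners and a trailing "@@".
    return _BPE_MARKER.sub("", line)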
Example #12
    def write_shuffled_data(self, shuffled_src_file_path,
                            shuffled_tgt_file_path):

        write_data(self.shuffled_src_data, shuffled_src_file_path)
        write_data(self.shuffled_tgt_data, shuffled_tgt_file_path)
Example #13
def evaluation(local_rank, args):
    rank = args.nr * args.gpus + local_rank
    dist.init_process_group(backend="nccl",
                            init_method=args.init_method,
                            rank=rank,
                            world_size=args.world_size)

    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    # List[str]
    src_data = read_data(args.test_src_path)

    tgt_prefix_data = None
    if args.tgt_prefix_file_path is not None:
        tgt_prefix_data = read_data(args.tgt_prefix_file_path)

    max_src_len = max(len(line.split()) for line in src_data) + 2
    max_tgt_len = max_src_len * 3
    logging.info("max src sentence length: {}".format(max_src_len))

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    padding_value = src_vocab.get_index(src_vocab.mask_token)

    assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

    src_data = convert_data_to_index(src_data, src_vocab)

    dataset = DataPartition(src_data, args.world_size, tgt_prefix_data,
                            args.work_load_per_process).dataset(rank)

    logging.info("dataset size: {}, rank: {}".format(len(dataset), rank))

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=(args.batch_size if args.batch_size else 1),
        shuffle=False,
        pin_memory=True,
        drop_last=False,
        collate_fn=lambda batch: collate_eval(
            batch,
            padding_value,
            batch_first=(True if args.transformer else False)))

    if not os.path.isdir(args.translation_output_dir):
        os.makedirs(args.translation_output_dir)

    if args.beam_size:
        logging.info("Beam size: {}".format(args.beam_size))

    if args.is_prefix:
        args.model_load = args.model_load + "*"

    for model_path in glob.glob(args.model_load):
        logging.info("Load model from: {}, rank: {}".format(model_path, rank))

        if args.transformer:

            s2s = load_transformer(model_path,
                                   len(src_vocab),
                                   max_src_len,
                                   len(tgt_vocab),
                                   max_tgt_len,
                                   padding_value,
                                   training=False,
                                   share_dec_pro_emb=args.share_dec_pro_emb,
                                   device=device)

        else:
            s2s = load_model(model_path, device=device)

        s2s.eval()

        if args.record_time:
            import time
            start_time = time.time()

        pred_data = []

        for data, tgt_prefix_batch in data_loader:
            if args.beam_size:
                pred_data.append(
                    beam_search_decoding(s2s, data.to(device,
                                                      non_blocking=True),
                                         tgt_vocab, args.beam_size, device))
            else:
                pred_data.extend(
                    greedy_decoding(s2s, data.to(device, non_blocking=True),
                                    tgt_vocab, device, tgt_prefix_batch))

        if args.record_time:
            end_time = time.time()
            logging.info("Time spend: {} seconds, rank: {}".format(
                end_time - start_time, rank))

        _, model_name = os.path.split(model_path)

        if args.beam_size:
            translation_file_name_prefix = "{}_beam_size_{}".format(
                model_name, args.beam_size)
        else:
            translation_file_name_prefix = "{}_greedy".format(model_name)

        p = os.path.join(
            args.translation_output_dir,
            "{}_translations.rank{}".format(translation_file_name_prefix,
                                            rank))

        write_data(pred_data, p)

        if args.need_tok:

            # replace '@@ ' with ''
            p_tok = os.path.join(
                args.translation_output_dir,
                "{}_translations_tok.rank{}".format(
                    translation_file_name_prefix, rank))

            tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
                p, p_tok)

            call(tok_command, shell=True)
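The evaluation(local_rank, args) signature together with the dist.init_process_group call suggests one process per GPU, launched from a driver that is not shown here. A hedged sketch of a typical launcher using torch.multiprocessing.spawn (the wrapper function and its argument handling are assumptions):

import torch.multiprocessing as mp


def launch(args):
    # Assumption: args.gpus processes on this node; spawn passes the local rank
    # as the first positional argument, matching evaluation(local_rank, args).
    mp.spawn(evaluation, args=(args,), nprocs=args.gpus, join=True)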
Example #14
                test_tgt.extend(data)
            else:
                test_src.extend(data)

        else:
            raise Exception("Error!")

assert len(train_src) == len(train_tgt)
assert len(dev_src) == len(dev_tgt)
assert len(test_src) == len(test_tgt)


def remove_blank(src_data):

    src_data = ["".join(sentence.strip().split()) for sentence in src_data]
    return src_data


train_src = remove_blank(train_src)
dev_src = remove_blank(dev_src)
test_src = remove_blank(test_src)

write_data(train_src, os.path.join(output_dir, "train.zh"))
write_data(train_tgt, os.path.join(output_dir, "train.en"))

write_data(dev_src, os.path.join(output_dir, "dev.zh"))
write_data(dev_tgt, os.path.join(output_dir, "dev.en"))

write_data(test_src, os.path.join(output_dir, "test.zh"))
write_data(test_tgt, os.path.join(output_dir, "test.en"))
Example #15
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--device", required=True)
    parser.add_argument("--load", required=True)
    parser.add_argument("--src_vocab_path", required=True)
    parser.add_argument("--tgt_vocab_path", required=True)
    parser.add_argument("--test_src_path", required=True)
    parser.add_argument("--test_tgt_path", required=True)
    parser.add_argument("--lang_vec_path", required=True)
    parser.add_argument("--translation_output", required=True)
    parser.add_argument("--transformer", action="store_true")

    args, unknown = parser.parse_known_args()

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    src_data = read_data(args.test_src_path)

    lang_vec = load_lang_vec(args.lang_vec_path)  # lang_vec: dict

    device = args.device

    print("load from {}".format(args.load))

    if args.transformer:

        # note: this counts characters; Example #13 uses len(line.split()) instead
        max_src_length = max(len(line) for line in src_data) + 2
        max_tgt_length = max_src_length * 3
        padding_value = src_vocab.get_index(src_vocab.mask_token)
        assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

        s2s = load_transformer(args.load,
                               len(src_vocab),
                               max_src_length,
                               len(tgt_vocab),
                               max_tgt_length,
                               padding_value,
                               device=device)

    else:
        s2s = load_model(args.load, device=device)

    s2s.eval()

    pred_data = []
    for i, line in enumerate(src_data):
        pred_data.append(
            translate(line, i, s2s, src_vocab, tgt_vocab, lang_vec, device))

    write_data(pred_data, args.translation_output)

    pred_data_tok_path = args.translation_output + ".tok"

    tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
        args.translation_output, pred_data_tok_path)
    call(tok_command, shell=True)

    bleu_calculation_command = "perl /data/rrjin/NMT/scripts/multi-bleu.perl {} < {}".format(
        args.test_tgt_path, pred_data_tok_path)
    call(bleu_calculation_command, shell=True)