Ejemplo n.º 1
0
def main(args):
    """Build dictionaries and write the preprocessed datasets named by ``args``.

    The function performs, in order:

    1. creates ``args.destdir`` and attaches a ``preprocess.log`` file
       handler to the module-level ``logger``;
    2. loads (``--srcdict`` / ``--tgtdict``) or builds the source and target
       dictionaries and saves them under ``args.destdir``;
    3. writes the train/valid/test splits for the source language, for the
       target language (unless ``args.only_source``), and once more for the
       source language using a ``BertTokenizer`` loaded from
       ``args.bert_model_name``;
    4. binarizes alignment files when ``args.align_suffix`` is set;
    5. derives ``alignment.<src>-<tgt>.txt`` from ``args.alignfile`` when set.

    Args:
        args: parsed command-line namespace. Attributes read here include
            ``destdir``, ``task``, ``trainpref``, ``validpref``, ``testpref``,
            ``source_lang``, ``target_lang``, ``only_source``, ``srcdict``,
            ``tgtdict``, ``joined_dictionary``, ``workers``, ``dataset_impl``,
            ``bert_model_name``, ``align_suffix`` and ``alignfile``.

    Raises:
        FileExistsError: if a dictionary would be built but its output file
            already exists in ``args.destdir``.
        AssertionError: if a required option combination is not satisfied
            (e.g. no ``--trainpref`` when a dictionary must be built).
        ValueError: if ``args.alignfile`` and the source/target training
            files do not have the same number of lines.
    """
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    # Mirror everything logged during preprocessing into the output directory.
    logger.addHandler(
        logging.FileHandler(filename=os.path.join(args.destdir,
                                                  "preprocess.log"), ))
    logger.info(args)

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Return ``<trainpref>.<lang>`` (or just ``<trainpref>`` if no lang)."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Return ``prefix`` with ``.<lang>`` appended unless ``lang`` is None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` inside ``destdir``."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Return the destination path of the dictionary file for ``lang``."""
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary from ``filenames`` via the task.

        Exactly one of ``src`` / ``tgt`` must be true; it selects which
        threshold / nwords options are forwarded.
        """
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    target = not args.only_source

    # Refuse to clobber a dictionary file that this run would have to build.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert (
            not args.srcdict or not args.tgtdict
        ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert (args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
            # A set is used so that identical paths are only counted once.
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True,
            )
        # Source and target share one dictionary object.
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert (args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        """Binarize ``<input_prefix>[.<lang>]`` with ``vocab``.

        When ``vocab`` is a ``BertTokenizer`` both the input and the output
        prefix get a ``.bert`` suffix. The first slice of the file (up to
        ``offsets[1]``) is binarized in this process; the remaining slices
        are dispatched to a worker pool and their partial datasets are
        merged into the main one and then deleted.
        """
        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
        output_prefix += '.bert' if isinstance(vocab, BertTokenizer) else ''
        input_prefix += '.bert' if isinstance(vocab, BertTokenizer) else ''
        # [number of sequences, number of tokens], mutated by merge_result.
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            """Accumulate the statistics returned by one binarization run."""
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab),
        )
        merge_result(
            Binarizer.binarize(input_file,
                               vocab,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        # An empty input file yields zero tokens; report 0% instead of
        # failing with ZeroDivisionError after the dataset has been written.
        total_tokens = n_seq_tok[1]
        replaced_pct = (100 * sum(replaced.values()) / total_tokens
                        if total_tokens else 0.0)
        # ``vocab`` may be a tokenizer rather than a dictionary; fall back to
        # ``unk_token`` when it exposes no ``unk_word`` attribute.
        unk_label = getattr(vocab, "unk_word", None)
        if unk_label is None:
            unk_label = getattr(vocab, "unk_token", None)

        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                total_tokens,
                replaced_pct,
                unk_label,
            ))

    def make_binary_alignment_dataset(input_prefix, output_prefix,
                                      num_workers):
        """Binarize the alignment file ``input_prefix`` into ``output_prefix``.

        Uses the same split / merge scheme as ``make_binary_dataset`` but
        parses lines with ``utils.parse_alignment`` and only tracks the
        number of parsed sequences.
        """
        nseq = [0]

        def merge_result(worker_result):
            """Accumulate the sequence count returned by one run."""
            nseq[0] += worker_result["nseq"]

        input_file = input_prefix
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (
                        args,
                        input_file,
                        utils.parse_alignment,
                        prefix,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, None, "bin"),
                                          impl=args.dataset_impl)

        merge_result(
            Binarizer.binarize_alignments(
                input_file,
                utils.parse_alignment,
                lambda t: ds.add_item(t),
                offset=0,
                end=offsets[1],
            ))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info("[alignments] {}: parsed {} alignments".format(
            input_file, nseq[0]))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Write one split: copy the text for ``raw``, binarize otherwise."""
        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)

    def make_all(lang, vocab):
        """Write the train, valid and test splits of ``lang`` using ``vocab``.

        ``validpref`` / ``testpref`` may hold several comma-separated
        prefixes; the first is written as ``valid`` / ``test`` and the k-th
        (k > 0) as ``valid<k>`` / ``test<k>``.
        """
        if args.trainpref:
            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    def make_all_alignments():
        """Binarize ``<pref>.<align_suffix>`` for each split whose file exists."""
        if args.trainpref and os.path.exists(args.trainpref + "." +
                                             args.align_suffix):
            make_binary_alignment_dataset(
                args.trainpref + "." + args.align_suffix,
                "train.align",
                num_workers=args.workers,
            )
        if args.validpref and os.path.exists(args.validpref + "." +
                                             args.align_suffix):
            make_binary_alignment_dataset(
                args.validpref + "." + args.align_suffix,
                "valid.align",
                num_workers=args.workers,
            )
        if args.testpref and os.path.exists(args.testpref + "." +
                                            args.align_suffix):
            make_binary_alignment_dataset(
                args.testpref + "." + args.align_suffix,
                "test.align",
                num_workers=args.workers,
            )

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)
    # Second pass over the source side, encoded with the BERT tokenizer
    # (reads ``<pref>.bert.<lang>`` and writes ``<split>.bert`` outputs).
    berttokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
    make_all(args.source_lang, berttokenizer)
    if args.align_suffix:
        make_all_alignments()

    logger.info("Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        # freq_map[src_index][tgt_index] -> number of times the pair is aligned.
        freq_map = {}
        with open(args.alignfile, "r", encoding="utf-8") as align_file, \
                open(src_file_name, "r", encoding="utf-8") as src_file, \
                open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
            for a, s, t in zip_longest(align_file, src_file, tgt_file):
                # zip_longest pads the shorter files with None; fail with a
                # clear message instead of an opaque error further down.
                if a is None or s is None or t is None:
                    raise ValueError(
                        "{}, {} and {} must have the same number of "
                        "lines".format(args.alignfile, src_file_name,
                                       tgt_file_name))
                si = src_dict.encode_line(s, add_if_not_exist=False)
                ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                # Each whitespace-separated item is "<src_pos>-<tgt_pos>".
                for sai, tai in (pair.split("-") for pair in a.split()):
                    srcidx = si[int(sai)]
                    tgtidx = ti[int(tai)]
                    # Pairs involving an unknown word are ignored.
                    if srcidx == src_dict.unk() or tgtidx == tgt_dict.unk():
                        continue
                    assert srcidx != src_dict.pad()
                    assert srcidx != src_dict.eos()
                    assert tgtidx != tgt_dict.pad()
                    assert tgtidx != tgt_dict.eos()

                    tgt_counts = freq_map.setdefault(srcidx, {})
                    tgt_counts[tgtidx] = tgt_counts.get(tgtidx, 0) + 1

        # Keep, for every source word, its most frequently aligned target.
        align_dict = {
            srcidx: max(tgt_counts, key=tgt_counts.get)
            for srcidx, tgt_counts in freq_map.items()
        }

        with open(
                os.path.join(
                    args.destdir,
                    "alignment.{}-{}.txt".format(args.source_lang,
                                                 args.target_lang),
                ),
                "w",
                encoding="utf-8",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
Ejemplo n.º 2
0
def main(args):
    """Preprocess every source/target language pair named by ``args``.

    The function performs, in order:

    1. creates ``args.dest_dir``;
    2. when ``args.model`` is set, tokenizes the train/valid/test files of
       every language pair into ``<file>.tok``;
    3. builds and saves either one joined dictionary (``args.join_dict``)
       or one dictionary per language;
    4. when ``args.model`` contains ``"bert"`` or ``"xlmr"``, writes feature
       files for the dictionary entries;
    5. writes the train/valid/test datasets (and alignment datasets when
       ``args.align_suffix`` is set) for every language pair.

    Args:
        args: parsed command-line namespace. Attributes read here include
            ``dest_dir``, ``out_pre``, ``task``, ``source_langs``,
            ``target_langs``, ``train_pre``, ``valid_pre``, ``test_pre``,
            ``model``, ``lowercase``, ``join_dict``, ``workers``,
            ``dataset_impl``, ``cuda``, ``local_rank`` and ``align_suffix``.

    Raises:
        ValueError: if a language appears in both ``args.source_langs`` and
            ``args.target_langs``.
    """
    utils.import_user_module(args)
    print(args)
    os.makedirs(args.dest_dir, exist_ok=True)
    task = tasks.get_task(args.task)
    all_langs = list(set(args.source_langs + args.target_langs))

    def train_path(src_lang, tgt_lang, lang, prefix=args.train_pre, tok=None):
        """Return ``<prefix>.<src>-<tgt>[.<lang>]`` plus ``.tok`` if ``tok``."""
        path = "{}.{}-{}{}".format(prefix, src_lang, tgt_lang,
                                   ("." + lang) if lang else "")
        if tok:
            path += ".tok"
        return path

    def file_name(prefix, lang):
        """Return ``prefix`` with ``.<lang>`` appended unless ``lang`` is None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Return the output path for ``prefix`` / ``lang`` in ``dest_dir``.

        A list of languages is collapsed to its sorted, de-duplicated
        members joined by ``-``.
        """
        if isinstance(lang, list):
            lang = '-'.join(sorted(set(lang)))
        return os.path.join(args.dest_dir,
                            file_name(args.out_pre + prefix, lang))

    def dict_path(lang):
        """Return the destination path of the dictionary file for ``lang``."""
        return dest_path("dict", lang) + ".txt"

    def features_path(feature_pre, lang):
        """Return the destination path of a feature file for ``lang``."""
        return dest_path(feature_pre, lang) + ".txt"

    def build_dictionary(filenames):
        """Build a dictionary from ``filenames`` via the task."""
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.threshold,
            nwords=args.nwords,
            padding_factor=args.padding_factor,
        )

    def tokenize_file(prefix, src_lang, tgt_lang):
        """Tokenize both sides of one language pair into ``<file>.tok``.

        Does nothing when ``prefix`` is empty (split not provided).
        """
        if not prefix:
            return
        for lang in (src_lang, tgt_lang):
            input_path = train_path(src_lang, tgt_lang, lang, prefix=prefix)
            tokenize(input_path,
                     input_path + '.tok',
                     model=args.model,
                     lowercase=args.lowercase)

    # Validate the language lists before doing any work so that a bad
    # configuration does not leave partially tokenized files behind.
    # No check is made here for pre-existing dictionary files.
    if set(args.source_langs) & set(args.target_langs):
        raise ValueError(
            "Source language and target language lists cannot overlap.")

    if args.model:
        for sl in args.source_langs:
            for tl in args.target_langs:
                for pref in (args.train_pre, args.valid_pre, args.test_pre):
                    tokenize_file(pref, sl, tl)

    if args.join_dict:
        # One dictionary built from both sides of every training pair and
        # saved once per language.
        joined_dict = build_dictionary({
            train_path(sl, tl, sl, tok=args.model)
            for sl in args.source_langs for tl in args.target_langs
        } | {
            train_path(sl, tl, tl, tok=args.model)
            for sl in args.source_langs for tl in args.target_langs
        })
        for lang in all_langs:
            joined_dict.save(dict_path(lang))
    else:
        dicts = {}
        for sl in args.source_langs:
            dicts[sl] = build_dictionary({
                train_path(sl, tl, sl, tok=args.model)
                for tl in args.target_langs
            })
        for tl in args.target_langs:
            dicts[tl] = build_dictionary({
                train_path(sl, tl, tl, tok=args.model)
                for sl in args.source_langs
            })
        for lang, dic in dicts.items():
            dic.save(dict_path(lang))

    # Convert vocabulary to features if necessary
    def convert_dict_to_examples(dic):
        """Wrap each regular dictionary symbol in an ``InputExample``.

        The first ``dic.nspecial`` symbols and any symbol containing
        ``"madeupword"`` are skipped; ``unique_id`` counts kept symbols.
        """
        examples = []
        unique_id = 0
        for i, sym in enumerate(dic.symbols):
            if i < dic.nspecial:
                continue
            if "madeupword" in sym:
                continue
            examples.append(
                InputExample(unique_id=unique_id, text_a=sym, text_b=None))
            unique_id += 1
        return examples

    def dict_to_wordlist(dic):
        """Return the regular symbols of ``dic`` as a list of strings.

        The first ``dic.nspecial`` symbols and any symbol containing
        ``"madeupword"`` are skipped.
        """
        return [
            sym for i, sym in enumerate(dic.symbols)
            if i >= dic.nspecial and "madeupword" not in sym
        ]

    if args.local_rank == -1 or not args.cuda:
        device = torch.device("cuda:{}".format(
            args.cuda) if torch.cuda.is_available() and args.cuda else "cpu")
        n_gpu = 0
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    if args.model:
        # (language key, dictionary) pairs for which features are written:
        # a single entry keyed by all languages for a joined dictionary,
        # otherwise one entry per language.
        if args.join_dict:
            feature_jobs = [(all_langs, joined_dict)]
        else:
            feature_jobs = list(dicts.items())

        if "bert" in args.model:
            for lang, dic in feature_jobs:
                examples = convert_dict_to_examples(dic)
                write_features_from_examples(examples,
                                             features_path(args.model, lang),
                                             args.model,
                                             args.layer,
                                             device,
                                             args.batch_size,
                                             max_tokens=3,
                                             tokenized=True,
                                             local_rank=args.local_rank,
                                             n_gpu=n_gpu,
                                             lowercase=args.lowercase,
                                             pool=args.pool)
        elif "xlmr" in args.model:
            # NOTE(review): this branch reads ``args.layers`` while the BERT
            # branch reads ``args.layer`` — confirm both options exist.
            for lang, dic in feature_jobs:
                # Always pass the filtered word list (never the dictionary
                # object itself), for joined and per-language runs alike.
                wordlist = dict_to_wordlist(dic)
                wordlist_to_xlmr_features(wordlist,
                                          features_path(args.model, lang),
                                          args.model, args.layers)

    def make_binary_dataset(vocab, input_prefix, output_prefix, src_lang,
                            tgt_lang, lang, num_workers):
        """Binarize ``<input_prefix>.<src>-<tgt>.<lang>[.tok]`` with ``vocab``.

        The first slice of the file (up to ``offsets[1]``) is binarized in
        this process; the remaining slices are dispatched to a worker pool
        and their partial datasets are merged into the main one and deleted.
        """
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        # [number of sequences, number of tokens], mutated by merge_result.
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            """Accumulate the statistics returned by one binarization run."""
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}.{}-{}.{}".format(input_prefix, src_lang, tgt_lang,
                                          lang)
        if args.model:
            input_file += ".tok"
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = multiprocessing.Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (args, input_file, vocab, prefix, src_lang, tgt_lang, lang,
                     offsets[worker_id], offsets[worker_id + 1]),
                    callback=merge_result)
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, src_lang, tgt_lang, lang, "bin"),
                                          impl=args.dataset_impl,
                                          vocab_size=len(vocab))
        merge_result(
            Binarizer.binarize(input_file,
                               vocab,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, src_lang,
                                                     tgt_lang, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(
            dataset_dest_file(args, output_prefix, src_lang, tgt_lang, lang,
                              "idx"))

        # An empty input file yields zero tokens; report 0% instead of
        # failing with ZeroDivisionError after the dataset has been written.
        total_tokens = n_seq_tok[1]
        replaced_pct = (100 * sum(replaced.values()) / total_tokens
                        if total_tokens else 0.0)
        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            total_tokens,
            replaced_pct,
            vocab.unk_word,
        ))

    def make_binary_alignment_dataset(input_prefix, output_prefix, src, tgt,
                                      num_workers):
        """Binarize the alignment file ``input_prefix`` for the pair src-tgt.

        Each line is parsed as whitespace-separated integers into a
        ``torch.IntTensor``; only the number of parsed sequences is tracked.
        """
        nseq = [0]

        def merge_result(worker_result):
            """Accumulate the sequence count returned by one run."""
            nseq[0] += worker_result['nseq']

        # NOTE(review): a lambda cannot be pickled by the standard pickle
        # module, so handing it to pool workers below is likely to fail when
        # num_workers > 1 — confirm and move it to a module-level function.
        parse_alignment = lambda s: torch.IntTensor(
            [int(t) for t in s.split()])
        input_file = input_prefix
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = multiprocessing.Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (args, input_file, parse_alignment, prefix, src, tgt,
                     offsets[worker_id], offsets[worker_id + 1]),
                    callback=merge_result)
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, src, tgt, None, "bin"),
                                          impl=args.dataset_impl)

        merge_result(
            Binarizer.binarize_alignments(input_file,
                                          parse_alignment,
                                          lambda t: ds.add_item(t),
                                          offset=0,
                                          end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                # Alignments have no language suffix: pass ``None`` for the
                # ``lang`` argument, as the dataset_dest_file calls do.
                temp_file_path = dataset_dest_prefix(args, prefix, src, tgt,
                                                     None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(
            dataset_dest_file(args, output_prefix, src, tgt, None, "idx"))

        print("| [alignments] {}: parsed {} alignments".format(
            input_file, nseq[0]))

    def make_dataset(vocab,
                     input_prefix,
                     output_prefix,
                     src_lang,
                     tgt_lang,
                     lang,
                     num_workers=1):
        """Write one split: copy the text for ``raw``, binarize otherwise."""
        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            # NOTE(review): the source path here is ``<input_prefix>.<lang>``
            # whereas the binary path reads ``<input_prefix>.<src>-<tgt>.<lang>``
            # — confirm which naming the raw inputs actually use.
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(src_lang, tgt_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, src_lang,
                                tgt_lang, lang, num_workers)

    def make_all(src_lang, tgt_lang):
        """Write train/valid/test for both sides of one language pair."""
        splits = (
            (args.train_pre, "train"),
            (args.valid_pre, "valid"),
            (args.test_pre, "test"),
        )
        for input_prefix, split_name in splits:
            if not input_prefix:
                continue
            # Source side first, then target side, for every split.
            for lang in (src_lang, tgt_lang):
                vocab = joined_dict if args.join_dict else dicts[lang]
                make_dataset(vocab,
                             input_prefix,
                             split_name,
                             src_lang,
                             tgt_lang,
                             lang,
                             num_workers=args.workers)

    def make_all_alignments(src, tgt):
        """Binarize ``<pref>.<src>-<tgt>.<align_suffix>`` for each split."""
        splits = (
            (args.train_pre, "train.align"),
            (args.valid_pre, "valid.align"),
            (args.test_pre, "test.align"),
        )
        for input_prefix, output_prefix in splits:
            if not input_prefix:
                continue
            align_path = input_prefix + ".{}-{}.".format(
                src, tgt) + args.align_suffix
            make_binary_alignment_dataset(align_path,
                                          output_prefix,
                                          src,
                                          tgt,
                                          num_workers=args.workers)

    for src in args.source_langs:
        for tgt in args.target_langs:
            make_all(src, tgt)
            if args.align_suffix:
                make_all_alignments(src, tgt)

    print("| Wrote preprocessed data to {}".format(args.dest_dir))
def main(args):
    """Build or load the dictionaries, then binarize the target-side data.

    The source training file is fed to ``NstackTreeTokenizer`` when a
    dictionary has to be built; the target file goes through the regular
    dictionary/binarizer path. As written, only the target-side validation
    and test sets are produced: the source-side call and the target-side
    training split are switched off further down.

    Args:
        args: Parsed option namespace. Fields read here include ``destdir``,
            ``only_source``, ``task``, ``trainpref``, ``validpref``,
            ``testpref``, ``convert_raw``, ``convert_raw_only``, the
            ``no_*`` tree switches, ``srcdict``, ``tgtdict``,
            ``joined_dictionary``, ``source_lang``, ``target_lang``,
            ``workers``, ``eval_workers``, the ``threshold*``/``nwords*``
            limits, ``padding_factor``, ``output_format`` and ``alignfile``.

    Raises:
        NotImplementedError: If ``args.convert_raw`` or ``args.alignfile``
            is set (the latter only after all data has been written).
        FileExistsError: If ``args.srcdict`` is not given and the source
            dictionary file already exists in ``args.destdir``.
    """
    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Return ``args.trainpref`` with ``.<lang>`` appended when lang is set."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    # Raw conversion is not implemented in this script.
    if args.convert_raw:
        print('start --- args.convert_raw')
        raise NotImplementedError

    if args.convert_raw_only:
        print('Finish!.')
        return

    # The namespace carries negated switches (no_*); invert them once here.
    remove_root = not args.no_remove_root
    take_pos_tag = not args.no_take_pos_tag
    take_nodes = not args.no_take_nodes
    reverse_node = not args.no_reverse_node
    no_collapse = args.no_collapse
    print(f'remove_root: {remove_root}')
    print(f'take_pos_tag: {take_pos_tag}')
    print(f'take_nodes: {take_nodes}')
    print(f'reverse_node: {reverse_node}')
    print(f'no_collapse: {no_collapse}')

    def file_name(prefix, lang):
        """Return ``prefix`` or ``prefix.<lang>`` when ``lang`` is not None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` inside ``args.destdir``."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Return the dictionary file path for ``lang`` in ``args.destdir``."""
        return dest_path("dict", lang) + ".txt"

    def share_dict_path():
        """Return ``args.share_dict_txt`` (currently not called in this function)."""
        return args.share_dict_txt

    def acquire_tree_vocab(d, tree_file):
        """Add the vocabulary of ``tree_file`` to ``d`` via NstackTreeTokenizer."""
        NstackTreeTokenizer.acquire_vocab_multithread(
            tree_file,
            d,
            tokenize_line,
            num_workers=args.workers,
            remove_root=remove_root,
            take_pos_tag=take_pos_tag,
            take_nodes=take_nodes,
            no_collapse=no_collapse,
        )

    def build_shared_nstack2seq_dictionary(_src_file, _tgt_file):
        """Build a single dictionary covering the source and the target file."""
        d = dictionary.Dictionary()
        print(f'Build dict on src_file: {_src_file}')
        acquire_tree_vocab(d, _src_file)
        print(f'Build dict on tgt_file: {_tgt_file}')
        dictionary.Dictionary.add_file_to_dictionary(_tgt_file,
                                                     d,
                                                     tokenize_line,
                                                     num_workers=args.workers)
        # The original expression tested an undefined name `src` here; a
        # joined dictionary is finalized with the source-side limits.
        d.finalize(threshold=args.thresholdsrc,
                   nwords=args.nwordssrc,
                   padding_factor=args.padding_factor)
        print(f'Finish building vocabulary: size {len(d)}')
        return d

    def build_nstack_source_dictionary(_src_file):
        """Build the source dictionary from ``_src_file`` only."""
        d = dictionary.Dictionary()
        print(f'Build dict on src_file: {_src_file}')
        acquire_tree_vocab(d, _src_file)
        # Source dictionary: use the source-side limits (the undefined
        # `src` selector of the original raised NameError).
        d.finalize(threshold=args.thresholdsrc,
                   nwords=args.nwordssrc,
                   padding_factor=args.padding_factor)
        print(f'Finish building src vocabulary: size {len(d)}')
        return d

    def build_target_dictionary(_tgt_file):
        """Build the target dictionary from ``_tgt_file`` through the task."""
        print(f'Build dict on tgt: {_tgt_file}')
        # Target dictionary: use the target-side limits (the undefined
        # `src` selector of the original raised NameError).
        d = task.build_dictionary(
            [_tgt_file],
            workers=args.workers,
            threshold=args.thresholdtgt,
            nwords=args.nwordstgt,
            padding_factor=args.padding_factor,
        )
        print(f'Finish building tgt vocabulary: size {len(d)}')
        return d

    # Refuse to overwrite a source dictionary that would be rebuilt. No
    # equivalent check is performed for the target dictionary.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_file = f'{args.trainpref}.{args.source_lang}'
            tgt_file = f'{args.trainpref}.{args.target_lang}'
            src_dict = build_shared_nstack2seq_dictionary(src_file, tgt_file)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_nstack_source_dictionary(
                train_path(args.source_lang))

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_target_dictionary(train_path(
                    args.target_lang))
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def print_stats(lang, input_file, nseq, ntok, nunk, vocab):
        """Print the sentence/token counts and the percentage given by nunk/ntok."""
        # An input without tokens would otherwise raise ZeroDivisionError.
        unk_percent = 100 * nunk / ntok if ntok else 0.0
        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            nseq,
            ntok,
            unk_percent,
            vocab.unk_word,
        ))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        """Binarize ``input_prefix[.lang]`` into a single indexed dataset."""
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))

        def consumer(tensor):
            ds.add_item(tensor)

        stat = BinarizerDataset.export_binarized_dataset(
            input_file,
            vocab,
            consumer,
            add_if_not_exist=False,
            num_workers=num_workers,
        )

        ntok = stat['ntok']
        nseq = stat['nseq']
        nunk = stat['nunk']

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print_stats(lang, input_file, nseq, ntok, nunk, vocab)

    def make_binary_nstack_dataset(vocab, input_prefix, output_prefix, lang,
                                   num_workers):
        """Binarize ``input_prefix[.lang]`` into one dataset per NSTACK_KEYS entry."""
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")

        # One builder per modality; each example is a mapping modality -> tensor.
        dss = {
            modality: NstackSeparateIndexedDatasetBuilder(
                dataset_dest_file_dptree(args, output_prefix, lang, 'bin',
                                         modality))
            for modality in NSTACK_KEYS
        }

        def consumer(example):
            for modality, tensor in example.items():
                dss[modality].add_item(tensor)

        stat = NstackTreeMergeBinarizerDataset.export_binarized_separate_dataset(
            input_file,
            vocab,
            consumer,
            add_if_not_exist=False,
            num_workers=num_workers,
            remove_root=remove_root,
            take_pos_tag=take_pos_tag,
            take_nodes=take_nodes,
            reverse_node=reverse_node,
            no_collapse=no_collapse,
        )
        ntok = stat['ntok']
        nseq = stat['nseq']
        nunk = stat['nunk']

        for modality, ds in dss.items():
            ds.finalize(
                dataset_dest_file_dptree(args, output_prefix, lang, "idx",
                                         modality))

        print_stats(lang, input_file, nseq, ntok, nunk, vocab)
        for modality in dss:
            print(f'\t{modality}')

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Write one split either binarized or as a copy of the raw text file."""
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_dptree_dataset(vocab,
                            input_prefix,
                            output_prefix,
                            lang,
                            num_workers=1):
        """Write one split through the Nstack binarizer (binary format only)."""
        if args.output_format != "binary":
            raise NotImplementedError(
                f'output format {args.output_format} not impl')

        make_binary_nstack_dataset(vocab, input_prefix, output_prefix, lang,
                                   num_workers)

    def make_all(lang, vocab):
        """Write the valid/test splits for ``lang``; the train split is skipped."""
        if args.trainpref:
            # The training split is not rebuilt here; only a warning is printed.
            print(
                '!!!! Warning..... Not during en-fr target because already done!.....'
            )

        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.eval_workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.eval_workers)

    def make_all_src(lang, vocab):
        """Write train/valid/test for ``lang`` through the Nstack binarizer."""
        if args.trainpref:
            make_dptree_dataset(vocab,
                                args.trainpref,
                                "train",
                                lang,
                                num_workers=args.workers)

        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dptree_dataset(vocab,
                                    validpref,
                                    outprefix,
                                    lang,
                                    num_workers=args.eval_workers)

        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dptree_dataset(vocab,
                                    testpref,
                                    outprefix,
                                    lang,
                                    num_workers=args.eval_workers)

    def make_all_tgt(lang, vocab):
        """Write the target-side splits (delegates to ``make_all``)."""
        make_all(lang, vocab)

    # Source-side processing is switched off; re-enable with
    # make_all_src(args.source_lang, src_dict).
    print('|||| WARNIONG no processing for source.')
    if target:
        make_all_tgt(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    # NOTE: this check runs only after all data has been written.
    if args.alignfile:
        raise NotImplementedError('alignfile Not impl at the moment')
Ejemplo n.º 4
0
def main(args):
    """Build dictionaries, binarize data with the XLNet tokenizer, and align.

    Steps, in order: resolve the task, load or build the source/target
    dictionaries and save them into ``args.destdir``, write the
    train/valid/test splits for the source (and, unless
    ``args.only_source``, the target) language, and finally, when
    ``args.alignfile`` is set, write an ``alignment.<src>-<tgt>.txt`` file
    mapping each source word to its most frequently aligned target word.

    Args:
        args: Parsed option namespace (``destdir``, ``task``, ``trainpref``,
            ``validpref``, ``testpref``, ``srcdict``, ``tgtdict``,
            ``joined_dictionary``, ``source_lang``, ``target_lang``,
            ``workers``, ``output_format``, ``alignfile`` and the
            dictionary size limits are read here).

    Raises:
        FileExistsError: If a dictionary would be rebuilt but its output
            file already exists in ``args.destdir``.
    """
    from fairseq import utils
    utils.xpprint(args)

    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Return ``args.trainpref`` with ``.<lang>`` appended when lang is set."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Return ``prefix`` or ``prefix.<lang>`` when ``lang`` is not None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` inside ``args.destdir``."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Return the dictionary file path for ``lang`` in ``args.destdir``."""
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary with the source- or target-side size limits."""
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    # Refuse to overwrite dictionaries that would be rebuilt.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = xlnet_dictionary.XLNetDictionary.load(args.srcdict)
            print('load xlnet dict from {} | size {}'.format(
                args.srcdict, len(src_dict)))
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = xlnet_dictionary.XLNetDictionary.load(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        """Tokenize ``input_prefix[.lang]`` with XLNet and write an indexed dataset.

        Each input line is split on ``<S_SEP>``; every piece is tokenized
        with the pretrained ``xlnet-base-cased`` tokenizer and the pieces
        are joined with the dictionary's ``sep_index``. ``vocab`` is only
        used for the size report and ``num_workers`` is ignored: the file
        is processed sequentially in this process.
        """
        from pytorch_transformers import XLNetTokenizer
        import torch

        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        print('input_prefix', input_prefix)
        print(dict_path(lang))

        # The saved dictionary is reloaded to obtain the sep/unk indices.
        xl_dict = xlnet_dictionary.XLNetDictionary.load(dict_path(lang))
        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")

        tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

        # Escaped quote/bracket tokens and the literal characters they are
        # replaced with before tokenization. Built once per call.
        penn2orig = {
            "``": '"',
            "''": '"',
            "-LRB-": '(',
            "-RRB-": ')',
            "-LSB-": '[',
            "-RSB-": ']',
            "-LCB-": '{',
            "-RCB-": '}'
        }

        def penn_token2orig_token(sent):
            """Replace the escaped tokens of ``sent`` by their literal characters."""
            return ' '.join(
                penn2orig.get(wd, wd) for wd in sent.strip().split())

        num_token, num_unk_token = 0, 0
        num_seq = 0
        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        # Use a context manager so the input file is closed deterministically.
        with open(input_file, encoding='utf8') as reader:
            for line in reader:
                sents = [
                    tokenizer.tokenize(penn_token2orig_token(sent))
                    for sent in line.strip().split('<S_SEP>')
                ]
                article_wids = []
                for i, sent in enumerate(sents):
                    if i != 0:
                        article_wids.append(xl_dict.sep_index)
                    wids = tokenizer.convert_tokens_to_ids(sent)
                    article_wids.extend(wids)
                    for wid in wids:
                        if wid == xl_dict.unk_index:
                            num_unk_token += 1
                        num_token += 1

                num_seq += 1
                ds.add_item(torch.IntTensor(article_wids))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        # An input without tokens would otherwise raise ZeroDivisionError.
        unk_percent = 100 * num_unk_token / num_token if num_token else 0.0
        print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
            lang, input_file, num_seq, num_token, unk_percent,
            xl_dict.unk_word
            if hasattr(xl_dict, 'unk_word') else '<no_unk_word>'))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Write one split either binarized or as a copy of the raw text file."""
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang, vocab):
        """Write every configured train/valid/test split for ``lang``."""
        if args.trainpref:
            print(args.trainpref, lang)

            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        # freq_map[src index][tgt index] -> number of times the pair is aligned.
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file, \
                open(src_file_name, "r", encoding='utf-8') as src_file, \
                open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
            for a, s, t in zip_longest(align_file, src_file, tgt_file):
                si = src_dict.encode_line(s, add_if_not_exist=False)
                ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                # Each alignment token has the form "<src pos>-<tgt pos>".
                ai = [tuple(x.split("-")) for x in a.split()]
                for sai, tai in ai:
                    srcidx = si[int(sai)]
                    tgtidx = ti[int(tai)]
                    # Pairs involving an unknown word are not counted.
                    if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                        assert srcidx != src_dict.pad()
                        assert srcidx != src_dict.eos()
                        assert tgtidx != tgt_dict.pad()
                        assert tgtidx != tgt_dict.eos()

                        tgt_counts = freq_map.setdefault(srcidx, {})
                        tgt_counts[tgtidx] = tgt_counts.get(tgtidx, 0) + 1

        # Keep, for every source index, the target index seen most often.
        align_dict = {
            srcidx: max(tgt_counts, key=tgt_counts.get)
            for srcidx, tgt_counts in freq_map.items()
        }

        align_out = os.path.join(
            args.destdir,
            "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
        )
        with open(align_out, "w", encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
Ejemplo n.º 5
0
def main(args):
    """Run the preprocessing pipeline described by ``args``.

    Creates ``args.destdir``, attaches a ``preprocess.log`` file handler to
    the module logger, loads or builds the source/target dictionaries and
    saves them, then (unless ``args.dict_only``) writes the datasets for
    both languages, the alignment datasets when ``args.align_suffix`` is
    set, and the alignment file when ``args.alignfile`` is set.

    Raises:
        FileExistsError: If a dictionary would be rebuilt but its output
            file already exists in ``args.destdir``.
    """
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    log_file = os.path.join(args.destdir, "preprocess.log")
    logger.addHandler(logging.FileHandler(filename=log_file))
    logger.info(args)

    assert args.dataset_impl != "huffman", (
        "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly."
    )

    # ---- dictionaries -----------------------------------------------------
    has_target = not args.only_source

    # A dictionary that has to be built must not already exist on disk.
    if not args.srcdict:
        src_dict_file = _dict_path(args.source_lang, args.destdir)
        if os.path.exists(src_dict_file):
            raise FileExistsError(src_dict_file)

    if has_target and not args.tgtdict:
        tgt_dict_file = _dict_path(args.target_lang, args.destdir)
        if os.path.exists(tgt_dict_file):
            raise FileExistsError(tgt_dict_file)

    task = tasks.get_task(args.task)

    if args.joined_dictionary:
        assert not (args.srcdict and args.tgtdict), (
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"
        )

        # Whichever dictionary was supplied is shared by both sides.
        shared_dict_file = args.srcdict or args.tgtdict
        if shared_dict_file:
            src_dict = task.load_dictionary(shared_dict_file)
        else:
            assert args.trainpref, (
                "--trainpref must be set if --srcdict is not specified"
            )
            train_files = {
                _train_path(lang, args.trainpref)
                for lang in [args.source_lang, args.target_lang]
            }
            src_dict = _build_dictionary(
                train_files,
                task=task,
                args=args,
                src=True,
            )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, (
                "--trainpref must be set if --srcdict is not specified"
            )
            src_dict = _build_dictionary(
                [_train_path(args.source_lang, args.trainpref)],
                task=task,
                args=args,
                src=True,
            )

        tgt_dict = None
        if has_target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, (
                    "--trainpref must be set if --tgtdict is not specified"
                )
                tgt_dict = _build_dictionary(
                    [_train_path(args.target_lang, args.trainpref)],
                    task=task,
                    args=args,
                    tgt=True,
                )

    # ---- persist dictionaries ---------------------------------------------
    src_dict.save(_dict_path(args.source_lang, args.destdir))
    if has_target and tgt_dict is not None:
        tgt_dict.save(_dict_path(args.target_lang, args.destdir))

    if args.dict_only:
        return

    # ---- datasets ---------------------------------------------------------
    _make_all(args.source_lang, src_dict, args)
    if has_target:
        _make_all(args.target_lang, tgt_dict, args)

    if args.align_suffix:
        _make_all_alignments(args)

    logger.info(f"Wrote preprocessed data to {args.destdir}")

    if args.alignfile:
        _align_files(args, src_dict=src_dict, tgt_dict=tgt_dict)
Ejemplo n.º 6
0
def main(args):
    """Preprocess a parallel corpus together with an edge file.

    Steps, in order:

    1. Load (``--srcdict`` / ``--tgtdict``) or build (from ``--trainpref``)
       the source and target dictionaries, load the edge dictionary from
       ``--edgedict``, and save all of them under ``args.destdir``.
    2. For the source language (and the target language unless
       ``args.only_source``), copy the train/valid/test text files into
       ``args.destdir`` when ``args.dataset_impl == "raw"``.
    3. Encode the edge files (language suffix ``args.edge_lang``) with
       ``Binarizer.binarize_graph`` and write them as JSON lines.
    4. If ``args.alignfile`` is set, write an alignment dictionary that maps
       every source word to its most frequently aligned target word.

    Args:
        args: parsed command-line namespace; the attributes read are the ones
            referenced in the body (``destdir``, ``trainpref``, ``validpref``,
            ``testpref``, ``source_lang``, ``target_lang``, ``edge_lang``, ...).

    Raises:
        FileExistsError: if a dictionary that would be built already exists in
            ``args.destdir``, or if ``args.edgedict`` is not set.
        AssertionError: if a dictionary must be built but ``args.trainpref``
            is missing, or if both dictionaries are given together with
            ``--joined-dictionary``.
    """
    import json  # only needed for the edge output; kept local to this function

    utils.import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Path of the training file for ``lang`` (no suffix if ``lang`` is falsy)."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Append ``.lang`` to ``prefix`` unless ``lang`` is None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Same as ``file_name`` but located inside ``args.destdir``."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Destination path of the dictionary file for ``lang``."""
        return dest_path("dict", lang) + ".txt"

    def pair_dest_path(output_prefix, lang):
        """Destination path ``<prefix>.<src>-<tgt>[.<lang>]`` of a split file."""
        return dest_path(
            output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
            lang,
        )

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary with the source- or target-side settings."""
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    def iter_splits():
        """Yield ``(input_prefix, output_prefix)`` for train, valid*, test*.

        Comma-separated valid/test prefixes are numbered from the second
        entry on: ``valid``, ``valid1``, ``valid2``, ...
        """
        if args.trainpref:
            yield args.trainpref, "train"
        for split, prefixes in (("valid", args.validpref),
                                ("test", args.testpref)):
            if not prefixes:
                continue
            for k, input_prefix in enumerate(prefixes.split(",")):
                output_prefix = "{}{}".format(split, k) if k > 0 else split
                yield input_prefix, output_prefix

    # Refuse to overwrite a dictionary that this run would have to build.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))
    if not args.edgedict:
        # The exception class is kept for backward compatibility; the message
        # now says what is actually wrong instead of just echoing the value.
        raise FileExistsError(
            "--edgedict must be set to an edge dictionary file (got {!r})"
            .format(args.edgedict))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        # A joined dictionary is shared by both sides.
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    edge_dict = task.load_dictionary(args.edgedict)

    src_dict.save(dict_path(args.source_lang))
    edge_dict.save(dict_path('edge'))

    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Copy one text file to the destination folder.

        Only ``dataset_impl == "raw"`` is handled here; for any other value
        this function does nothing.
        """
        if args.dataset_impl == "raw":
            shutil.copyfile(file_name(input_prefix, lang),
                            pair_dest_path(output_prefix, lang))

    def make_all(lang, vocab):
        """Run ``make_dataset`` for every configured split of ``lang``."""
        for input_prefix, output_prefix in iter_splits():
            make_dataset(vocab,
                         input_prefix,
                         output_prefix,
                         lang,
                         num_workers=args.workers)

    def make_edge_dataset(vocab,
                          input_prefix,
                          output_prefix,
                          lang,
                          num_workers=1):
        """Encode one edge file with ``vocab`` and write it as JSON lines.

        ``num_workers`` is accepted for signature parity with
        ``make_dataset`` but the encoding runs in this process only.
        """
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        output_text_file = pair_dest_path(output_prefix, lang)

        encoded = []
        merge_result(
            Binarizer.binarize_graph(input_file, vocab, encoded.append))
        with open(output_text_file, 'w', encoding='utf-8') as f:
            for item in encoded:
                # Each consumed item exposes .numpy(); store it as a JSON list.
                f.write(json.dumps(item.numpy().tolist()) + '\n')

        # An empty input yields zero tokens; report 0% instead of dividing
        # by zero.
        n_tokens = n_seq_tok[1]
        replaced_pct = (100 * sum(replaced.values()) / n_tokens
                        if n_tokens else 0.0)
        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_tokens,
            replaced_pct,
            vocab.unk_word,
        ))

    def make_edge_all(lang, vocab):
        """Run ``make_edge_dataset`` for every configured split."""
        for input_prefix, output_prefix in iter_splits():
            make_edge_dataset(vocab,
                              input_prefix,
                              output_prefix,
                              lang,
                              num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    make_edge_all(args.edge_lang, edge_dict)
    if target:
        make_all(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)

        # freq_map[src_index][tgt_index] -> number of times the pair is aligned
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file, \
                open(src_file_name, "r", encoding='utf-8') as src_file, \
                open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
            # zip_longest pads the shorter inputs with None, so a missing
            # alignment line fails on .split() instead of being skipped.
            for a, s, t in zip_longest(align_file, src_file, tgt_file):
                si = src_dict.encode_line(s, add_if_not_exist=False)
                ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                for pair in a.split():
                    sai, tai = pair.split("-")
                    # NOTE(review): the element type returned by encode_line
                    # is not visible here. If it yields tensor scalars they
                    # hash by identity and the counts below would never
                    # aggregate, so normalise to plain ints.
                    srcidx = int(si[int(sai)])
                    tgtidx = int(ti[int(tai)])
                    if srcidx == src_dict.unk() or tgtidx == tgt_dict.unk():
                        continue
                    assert srcidx != src_dict.pad()
                    assert srcidx != src_dict.eos()
                    assert tgtidx != tgt_dict.pad()
                    assert tgtidx != tgt_dict.eos()

                    tgt_counts = freq_map.setdefault(srcidx, {})
                    tgt_counts[tgtidx] = tgt_counts.get(tgtidx, 0) + 1

        # Keep, for every source index, its most frequent target index.
        align_dict = {
            srcidx: max(tgt_counts, key=tgt_counts.get)
            for srcidx, tgt_counts in freq_map.items()
        }

        alignment_path = os.path.join(
            args.destdir,
            "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
        )
        with open(alignment_path, "w", encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def main(args):
    """Preprocess a corpus with hierarchical source, sent/doc side files and target.

    Steps, in order:

    1. Load (``--srcdict`` / ``--tgtdict``) or build (from ``--trainpref``)
       the source and target dictionaries and save them in ``args.destdir``.
    2. Source side: for ``dataset_impl == "raw"`` rewrite each text file with
       explicit sentence separators; otherwise binarize it with
       ``Binarizer.binarize_hierarchical``.
    3. Encode the ``args.sent_lang`` and ``args.doc_lang`` files with
       ``Binarizer.binarize_sent_doc`` and write them as JSON lines with the
       suffixes ``sent`` and ``doc``.
    4. Target side (unless ``args.only_source``): copy the raw text, or
       binarize it with ``Binarizer.binarize``.
    5. If ``args.alignfile`` is set, write an alignment dictionary that maps
       every source word to its most frequently aligned target word.

    Args:
        args: parsed command-line namespace; the attributes read are the ones
            referenced in the body.

    Raises:
        FileExistsError: if a dictionary that would be built already exists
            in ``args.destdir``.
        AssertionError: if a dictionary must be built but ``args.trainpref``
            is missing, or if both dictionaries are given together with
            ``--joined-dictionary``.
    """
    import json  # only needed for the sent/doc output; kept local

    utils.import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Path of the training file for ``lang`` (no suffix if ``lang`` is falsy)."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Append ``.lang`` to ``prefix`` unless ``lang`` is None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Same as ``file_name`` but located inside ``args.destdir``."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Destination path of the dictionary file for ``lang``."""
        return dest_path("dict", lang) + ".txt"

    def pair_dest_path(output_prefix, suffix):
        """Destination path ``<prefix>.<src>-<tgt>[.<suffix>]`` of a split file."""
        return dest_path(
            output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
            suffix,
        )

    def input_path(input_prefix, lang):
        """Path of the input file for ``lang`` under ``input_prefix``."""
        return "{}{}".format(input_prefix,
                             ("." + lang) if lang is not None else "")

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary with the source- or target-side settings."""
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    def iter_splits():
        """Yield ``(input_prefix, output_prefix)`` for train, valid*, test*.

        Comma-separated valid/test prefixes are numbered from the second
        entry on: ``valid``, ``valid1``, ``valid2``, ...
        """
        if args.trainpref:
            yield args.trainpref, "train"
        for split, prefixes in (("valid", args.validpref),
                                ("test", args.testpref)):
            if not prefixes:
                continue
            for k, input_prefix in enumerate(prefixes.split(",")):
                output_prefix = "{}{}".format(split, k) if k > 0 else split
                yield input_prefix, output_prefix

    def make_stats():
        """Return fresh ``(n_seq_tok, replaced, merge_result)`` accumulators.

        ``merge_result`` folds one binarizer result dict (keys ``replaced``,
        ``nseq``, ``ntok``) into the two counters.
        """
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        return n_seq_tok, replaced, merge_result

    # Refuse to overwrite a dictionary that this run would have to build.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        # A joined dictionary is shared by both sides.
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_sent_doc_dataset(input_prefix,
                              output_prefix,
                              lang,
                              num_workers=1,
                              output_lang=None,
                              output_text_file=None):
        """Encode one sent/doc file and write it to ``output_text_file`` as JSON lines.

        ``output_prefix`` and ``num_workers`` are accepted for signature
        parity with the other dataset builders but are not used.
        """
        n_seq_tok, _replaced, merge_result = make_stats()
        input_file = input_path(input_prefix, lang)

        encoded = []
        merge_result(Binarizer.binarize_sent_doc(input_file, encoded.append))

        with open(output_text_file, 'w', encoding='utf-8') as f:
            for item in encoded:
                # Each consumed item exposes .numpy(); store it as a JSON list.
                f.write(json.dumps(item.numpy().tolist()) + '\n')

        print("| [{}] {}: {} sents, {} tokens".format(
            output_lang,
            input_file,
            n_seq_tok[0],
            n_seq_tok[1],
        ))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers, hierarchical=False):
        """Binarize one text file into an indexed dataset (.bin / .idx).

        The first chunk is encoded in this process; the remaining
        ``num_workers - 1`` chunks are encoded by pool workers into temporary
        files that are merged into the final dataset and then removed.
        ``hierarchical`` selects the ``*_hierarchical`` encoder pair instead
        of the plain one.
        """
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok, replaced, merge_result = make_stats()

        input_file = input_path(input_prefix, lang)
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # Resolved only when workers are really used, as before.
            worker_fn = binarize_hierarchical if hierarchical else binarize
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    worker_fn,
                    (args, input_file, vocab, prefix, lang,
                     offsets[worker_id], offsets[worker_id + 1]),
                    callback=merge_result)
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl)
        encode = (Binarizer.binarize_hierarchical
                  if hierarchical else Binarizer.binarize)
        merge_result(
            encode(input_file,
                   vocab,
                   lambda t: ds.add_item(t),
                   offset=0,
                   end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        # An empty input yields zero tokens; report 0% instead of dividing
        # by zero.
        n_tokens = n_seq_tok[1]
        replaced_pct = (100 * sum(replaced.values()) / n_tokens
                        if n_tokens else 0.0)
        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_tokens,
            replaced_pct,
            vocab.unk_word,
        ))

    def split_sentence_and_copy(input_file, output_file):
        """Copy ``input_file`` to ``output_file`` with sentence separators inserted.

        Every line is split on ``story_separator_special_tag``; each non-empty
        part is sentence-tokenized and its sentences are joined with
        ``sentence_separator_special_tag``.
        """
        # Imported lazily: only the raw-text path needs nltk.
        import nltk

        sentence_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

        # Explicit utf-8, consistent with the other text files read by this
        # function, instead of the platform's default encoding.
        with open(input_file, encoding='utf-8') as f:
            lines = f.readlines()

        with open(output_file, 'w', encoding='utf-8') as f:
            for line in lines:
                stories = (
                    ' sentence_separator_special_tag '.join(
                        sentence_tokenizer.tokenize(paragraph))
                    for paragraph in line.split('story_separator_special_tag')
                    if paragraph
                )
                f.write(' story_separator_special_tag '.join(stories) + '\n')

    def make_dataset_hierarchical(vocab,
                                  input_prefix,
                                  output_prefix,
                                  lang,
                                  num_workers=1):
        """Write one source-side split: sentence-split raw text or binary."""
        if args.dataset_impl == "raw":
            split_sentence_and_copy(file_name(input_prefix, lang),
                                    pair_dest_path(output_prefix, lang))
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers, hierarchical=True)

    def make_dataset(vocab,
                     input_prefix,
                     output_prefix,
                     lang,
                     num_workers=1):
        """Write one target-side split: plain copy of raw text or binary."""
        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            shutil.copyfile(file_name(input_prefix, lang),
                            pair_dest_path(output_prefix, lang))
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)

    def make_all_hierarchical(lang, vocab):
        """Run ``make_dataset_hierarchical`` for every configured split."""
        for input_prefix, output_prefix in iter_splits():
            make_dataset_hierarchical(vocab,
                                      input_prefix,
                                      output_prefix,
                                      lang,
                                      num_workers=args.workers)

    def make_sent_doc_all(lang, kind):
        """Run ``make_sent_doc_dataset`` for every split; ``kind`` is 'sent' or 'doc'."""
        for input_prefix, output_prefix in iter_splits():
            make_sent_doc_dataset(
                input_prefix,
                output_prefix,
                lang,
                num_workers=args.workers,
                output_lang=kind,
                output_text_file=pair_dest_path(output_prefix, kind))

    def make_all(lang, vocab):
        """Run ``make_dataset`` for every configured split."""
        for input_prefix, output_prefix in iter_splits():
            make_dataset(vocab,
                         input_prefix,
                         output_prefix,
                         lang,
                         num_workers=args.workers)

    make_all_hierarchical(args.source_lang, src_dict)
    make_sent_doc_all(args.sent_lang, 'sent')
    make_sent_doc_all(args.doc_lang, 'doc')
    if target:
        make_all(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)

        # freq_map[src_index][tgt_index] -> number of times the pair is aligned
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file, \
                open(src_file_name, "r", encoding='utf-8') as src_file, \
                open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
            # zip_longest pads the shorter inputs with None, so a missing
            # alignment line fails on .split() instead of being skipped.
            for a, s, t in zip_longest(align_file, src_file, tgt_file):
                si = src_dict.encode_line(s, add_if_not_exist=False)
                ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                for pair in a.split():
                    sai, tai = pair.split("-")
                    # NOTE(review): the element type returned by encode_line
                    # is not visible here. If it yields tensor scalars they
                    # hash by identity and the counts below would never
                    # aggregate, so normalise to plain ints.
                    srcidx = int(si[int(sai)])
                    tgtidx = int(ti[int(tai)])
                    if srcidx == src_dict.unk() or tgtidx == tgt_dict.unk():
                        continue
                    assert srcidx != src_dict.pad()
                    assert srcidx != src_dict.eos()
                    assert tgtidx != tgt_dict.pad()
                    assert tgtidx != tgt_dict.eos()

                    tgt_counts = freq_map.setdefault(srcidx, {})
                    tgt_counts[tgtidx] = tgt_counts.get(tgtidx, 0) + 1

        # Keep, for every source index, its most frequent target index.
        align_dict = {
            srcidx: max(tgt_counts, key=tgt_counts.get)
            for srcidx, tgt_counts in freq_map.items()
        }

        alignment_path = os.path.join(
            args.destdir,
            "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
        )
        with open(alignment_path, "w", encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def main(args):
    """Preprocess oracle data: dictionaries, source data, action states, embeddings.

    Up to three stages are run, each guarded by a local flag:

    * ``run_basic``: split the target ``.actions`` files via
      ``task.split_actions_pointer_files``, load or build the dictionaries,
      save them, and write the source data (raw copy and binarized).
    * ``run_act_states``: instantiate the task and call
      ``build_actions_states_info`` for the train/valid/test prefixes.
    * ``run_roberta_emb``: call ``make_roberta_embeddings``.

    ``run_basic`` and ``run_act_states`` are switched off when
    ``args.destdir`` already exists, ``run_roberta_emb`` when ``args.embdir``
    already exists.

    Args:
        args: parsed command-line namespace (``destdir``, ``embdir``,
            ``trainpref``/``validpref``/``testpref``, ``source_lang``,
            ``target_lang``, dictionary and worker options, ...). When the
            basic stage runs, ``target_lang_nopos`` and ``target_lang_pos``
            are set on it as a side effect.

    Raises:
        FileExistsError: if a dictionary would be built but its output file
            already exists in ``args.destdir``.
        AssertionError: if the target extension is not ``"actions"``, if
            ``--trainpref`` is missing when a dictionary has to be built, or
            if both ``--srcdict`` and ``--tgtdict`` are given together with
            ``--joined-dictionary``.
        Exception: any exception raised inside a binarization worker is
            re-raised in the parent process.
    """
    import_user_module(args)

    print(args)

    # to control what preprocessing needs to be run (as they take both time and storage so we avoid running repeatedly)
    run_basic = True
    # this includes:
    # src: build src dictionary, copy the raw data to dir; build src binary data (need to refactor later if unneeded)
    # tgt: split target non-pointer actions and pointer values into separate files; build tgt dictionary
    run_act_states = True
    # this includes:
    # run the state machine reformer to get
    # a) training data: input and output, pointer values;
    # b) states information to facilitate modeling;
    # takes about 1 hour and 13G space on CCC
    run_roberta_emb = True
    # this includes:
    # for src sentences, use pre-trained RoBERTa model to extract contextual embeddings for each word;
    # takes about 10min for RoBERTa base and 30 mins for RoBERTa large and 2-3G space;
    # this needs GPU and only needs to run once for the English sentences, which does not change for different oracles;
    # thus the embeddings are stored separately from the oracles.

    # The existence checks must happen before the directories are created below.
    if os.path.exists(args.destdir):
        print(f'binarized actions and states directory {args.destdir} already exists; not rerunning.')
        run_basic = False
        run_act_states = False
    if os.path.exists(args.embdir):
        print(f'pre-trained embedding directory {args.embdir} already exists; not rerunning.')
        run_roberta_emb = False

    os.makedirs(args.destdir, exist_ok=True)
    os.makedirs(args.embdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    # preprocess target actions files, to split '.actions' to '.actions_nopos' and '.actions_pos'
    # when building dictionary on the target actions sequences
    # split the action file into two files, one without arc pointer and one with only arc pointer values
    # and the dictionary is only built on the no pointer actions
    if run_basic:
        assert args.target_lang == 'actions', 'target extension must be "actions"'
        actions_files = [f'{pref}.{args.target_lang}' for pref in (args.trainpref, args.validpref, args.testpref)]
        task.split_actions_pointer_files(actions_files)
        args.target_lang_nopos = 'actions_nopos'    # only build dictionary without pointer values
        args.target_lang_pos = 'actions_pos'

    # set tokenizer: prefer the task's own tokenizer when it defines one
    tokenize = task.tokenize if hasattr(task, 'tokenize') else tokenize_line

    def train_path(lang):
        """Return the training file path for ``lang`` (bare prefix if ``lang`` is falsy)."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Return ``prefix`` with ``.lang`` appended unless ``lang`` is None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` inside ``args.destdir``."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Return the dictionary file path for ``lang`` inside ``args.destdir``."""
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary from ``filenames`` using the source or target thresholds.

        Exactly one of ``src`` and ``tgt`` must be true.
        """
        assert src ^ tgt

        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
            # tokenize separator is taken care inside task
        )

    # build dictionary and save

    if run_basic:
        if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
            raise FileExistsError(dict_path(args.source_lang))
        if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
            raise FileExistsError(dict_path(args.target_lang))

        if args.joined_dictionary:
            assert not args.srcdict or not args.tgtdict, \
                "cannot use both --srcdict and --tgtdict with --joined-dictionary"

            if args.srcdict:
                src_dict = task.load_dictionary(args.srcdict)
            elif args.tgtdict:
                src_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
                src_dict = build_dictionary(
                    {train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
                )
            tgt_dict = src_dict
        else:
            if args.srcdict:
                src_dict = task.load_dictionary(args.srcdict)
            else:
                assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
                src_dict = build_dictionary([train_path(args.source_lang)], src=True)

            if target:
                if args.tgtdict:
                    tgt_dict = task.load_dictionary(args.tgtdict)
                else:
                    assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                    # the target dictionary is built on the actions without pointer values
                    tgt_dict = build_dictionary([train_path(args.target_lang_nopos)], tgt=True)
            else:
                tgt_dict = None

        src_dict.save(dict_path(args.source_lang))
        if target and tgt_dict is not None:
            tgt_dict.save(dict_path(args.target_lang_nopos))

    # save binarized preprocessed files

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
        """Binarize ``input_prefix.lang`` with ``vocab`` and write it under ``output_prefix``.

        The file is cut into ``num_workers`` chunks: the first chunk is
        binarized in this process, the others in a worker pool, and the
        worker shards are merged into the main dataset and then deleted.
        """
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            # accumulate the statistics reported for one chunk
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        async_results = []
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                async_results.append(
                    pool.apply_async(
                        binarize,
                        (
                            args,
                            input_file,
                            vocab,
                            prefix,
                            lang,
                            offsets[worker_id],
                            offsets[worker_id + 1],
                            False,    # note here we shut off append eos
                            tokenize
                        ),
                        callback=merge_result
                    )
                )
            pool.close()

        # NOTE: the builder is created with args.dataset_impl, not with the
        # ``dataset_impl`` override that make_dataset() receives.
        ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"),
                                          impl=args.dataset_impl, vocab_size=len(vocab), dtype=np.int64)
        merge_result(
            Binarizer.binarize(
                input_file, vocab, lambda t: ds.add_item(t),
                offset=0, end=offsets[1],
                append_eos=False,
                tokenize=tokenize
            )
        )
        if num_workers > 1:
            pool.join()
            # The success callback never fires for a failed task, so surface
            # worker exceptions here instead of merging incomplete shards.
            for async_result in async_results:
                async_result.get()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        # An input without tokens would otherwise raise ZeroDivisionError here.
        n_tokens = n_seq_tok[1]
        replaced_pct = 100 * sum(replaced.values()) / n_tokens if n_tokens else 0.0
        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_tokens,
                replaced_pct,
                vocab.unk_word,
            )
        )

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1, dataset_impl=args.dataset_impl):
        """Write one split: copy the text file for ``"raw"``, binarize it otherwise."""
        if dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)

    def make_all(lang, vocab, dataset_impl=args.dataset_impl):
        """Write the train split and every comma-separated valid/test split for ``lang``."""
        if args.trainpref:
            make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers, dataset_impl=dataset_impl)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers, dataset_impl=dataset_impl)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers, dataset_impl=dataset_impl)

    # NOTE we do not encode the source sentences with dictionary, as the source embeddings are directly provided
    # from RoBERTa, thus the source dictionary here is of no use
    if run_basic:
        make_all(args.source_lang, src_dict, dataset_impl='raw')
        make_all(args.source_lang, src_dict, dataset_impl='mmap')
        # above: just leave for the sake of model to run without too much change
        # NOTE there are <unk> in valid and test set for target actions
        # if target:
        #     make_all(args.target_lang_nopos, tgt_dict)

        # NOTE targets (input, output, pointer values) are now all included in the state generation process

        # binarize pointer values and save to file

        # TODO make naming convention clearer
        # assume one training file, one validation file, and one test file
        # for pos_file, split in [(f'{pref}.actions_pos', split) for pref, split in
        #                         [(args.trainpref, 'train'), (args.validpref, 'valid'), (args.testpref, 'test')]]:
        #     out_pref = os.path.join(args.destdir, split)
        #     task.binarize_actions_pointer_file(pos_file, out_pref)

    # save action states information to assist training with auxiliary info
    # assume one training file, one validation file, and one test file
    if run_act_states:
        # run_act_states is only left on when run_basic is, so tgt_dict is bound here
        task_obj = task(args, tgt_dict=tgt_dict)
        for prefix, split in zip([args.trainpref, args.validpref, args.testpref], ['train', 'valid', 'test']):
            en_file = prefix + '.en'
            actions_file = prefix + '.actions'
            out_file_pref = os.path.join(args.destdir, split)
            task_obj.build_actions_states_info(en_file, actions_file, out_file_pref, num_workers=args.workers)

    # save RoBERTa embeddings
    # TODO refactor this code
    if run_roberta_emb:
        make_roberta_embeddings(args, tokenize=tokenize)

    print("| Wrote preprocessed oracle data to {}".format(args.destdir))
    print("| Wrote preprocessed embedding data to {}".format(args.embdir))
def prepare_dict(args):
    """Load or build the source, target and ``word_level=False`` dictionaries.

    When the source dictionary, the target dictionary and ``dict_char.txt``
    are all present in ``args.destdir`` they are loaded and returned directly.
    Otherwise the source/target dictionaries are loaded from ``--srcdict`` /
    ``--tgtdict`` or built from the training files, a third dictionary is
    built from both training files with ``word_level=False``, and the results
    are saved into ``args.destdir``.

    Args:
        args: parsed command-line namespace.

    Returns:
        Tuple ``(src_dict, tgt_dict, char_dict)``. ``tgt_dict`` is the same
        object as ``src_dict`` with ``--joined-dictionary`` and ``None`` when
        ``--only-source`` is set without a joined dictionary.

    Raises:
        FileExistsError: if a dictionary would be built but its output file
            already exists.
        AssertionError: on inconsistent dictionary options or a missing
            ``--trainpref``.
    """
    utils.import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        # training file for ``lang``; the bare prefix when ``lang`` is falsy
        suffix = "." + lang if lang else ""
        return "{}{}".format(args.trainpref, suffix)

    def file_name(prefix, lang):
        # ``prefix`` with ``.lang`` appended unless ``lang`` is None
        if lang is None:
            return prefix
        return prefix + ".{lang}".format(lang=lang)

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False, word_level=True):
        # exactly one of ``src`` / ``tgt`` selects the threshold and size limits
        assert src ^ tgt
        threshold, nwords = (
            (args.thresholdsrc, args.nwordssrc)
            if src
            else (args.thresholdtgt, args.nwordstgt)
        )
        return task.build_dict(
            filenames,
            word_level=word_level,
            workers=args.workers,
            threshold=threshold,
            nwords=nwords,
            padding_factor=args.padding_factor,
        )

    src_dict_file = dict_path(args.source_lang)
    tgt_dict_file = dict_path(args.target_lang)
    char_dict_file = os.path.join(args.destdir, 'dict_char.txt')

    # Reuse a previous run when all three dictionary files are already there.
    cached_files = (src_dict_file, tgt_dict_file, char_dict_file)
    if all(os.path.exists(path) for path in cached_files):
        return (
            task.load_dictionary(src_dict_file),
            task.load_dictionary(tgt_dict_file),
            task.load_dictionary(char_dict_file),
        )

    # Refuse to overwrite a dictionary file that is about to be rebuilt.
    if not args.srcdict and os.path.exists(src_dict_file):
        raise FileExistsError(src_dict_file)
    if target and not args.tgtdict and os.path.exists(tgt_dict_file):
        raise FileExistsError(tgt_dict_file)

    both_train_files = [args.source_lang, args.target_lang]

    if args.joined_dictionary:
        assert not (args.srcdict and args.tgtdict), \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        # Either preset dictionary serves as the shared one; srcdict wins.
        preset = args.srcdict or args.tgtdict
        if preset:
            src_dict = task.load_dictionary(preset)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {train_path(lang) for lang in both_train_files},
                src=True,
            )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], src=True)

        tgt_dict = None
        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)

    src_dict.save(src_dict_file)
    if target and tgt_dict is not None:
        tgt_dict.save(tgt_dict_file)

    # Third dictionary: both training files, source thresholds, word_level off.
    char_dict = build_dictionary(
        {train_path(lang) for lang in both_train_files},
        src=True,
        word_level=False,
    )
    char_dict.save(char_dict_file)

    return src_dict, tgt_dict, char_dict
Ejemplo n.º 10
0
def main(args):
    """Build dictionaries, write the datasets and optional alignment label files.

    Steps:

    1. Load or build the source/target dictionaries and save them into
       ``args.destdir``.
    2. Write every train/valid/test split for the source language, then for
       the target language. The per-sentence word lists collected while
       binarizing a source split are handed to the binarizer of the matching
       target split as ``copy_src_words``.
    3. If ``args.alignfile`` is set, write ``train.label.<lang>.txt`` for both
       languages: a label is 0 for a token whose aligned counterpart is the
       identical word and 1 otherwise.

    Args:
        args: parsed command-line namespace.

    Raises:
        FileExistsError: if a dictionary would be built but its output file
            already exists.
        AssertionError: on inconsistent options (``--copy-ext-dict`` without
            ``--joined-dictionary`` or with more than one worker, both
            ``--srcdict`` and ``--tgtdict`` with ``--joined-dictionary``, or a
            missing ``--trainpref``).
        Exception: any exception raised inside a binarization worker is
            re-raised in the parent process.
    """
    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Return the training file path for ``lang`` (bare prefix if ``lang`` is falsy)."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Return ``prefix`` with ``.lang`` appended unless ``lang`` is None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` inside ``args.destdir``."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Return the dictionary file path for ``lang`` inside ``args.destdir``."""
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary from ``filenames``; exactly one of ``src``/``tgt`` must be true."""
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    # Refuse to overwrite a dictionary file that is about to be rebuilt.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.copy_ext_dict:
        assert args.joined_dictionary, \
            "--joined-dictionary must be set if --copy-extended-dictionary is specified"
        assert args.workers == 1, \
            "--workers must be set to 1 if --copy-extended-dictionary is specified"

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
            )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers, copy_src_words=None):
        """Binarize ``input_prefix.lang`` and return the word lists seen by this process.

        The first chunk of the file is binarized in this process and the
        remaining chunks in a worker pool; worker shards are merged into the
        main dataset and deleted. Only the chunk handled in this process
        contributes to the returned list, and ``copy_src_words`` is only
        passed to that chunk.
        """
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()
        copied = Counter()

        def merge_result(worker_result):
            # accumulate the statistics reported for one chunk
            replaced.update(worker_result["replaced"])
            copied.update(worker_result["copied"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        async_results = []
        if num_workers > 1:  # todo: not support copy
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                async_results.append(
                    pool.apply_async(
                        binarize,
                        (
                            args,
                            input_file,
                            vocab,
                            prefix,
                            lang,
                            offsets[worker_id],
                            offsets[worker_id + 1]
                        ),
                        callback=merge_result
                    )
                )
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin")
        )
        words_list = []

        def binarize_consumer(ids, words):
            # store the encoded sentence and remember its words
            ds.add_item(ids)
            words_list.append(words)

        merge_result(
            Binarizer.binarize(
                input_file, vocab, binarize_consumer,
                offset=0, end=offsets[1], copy_ext_dict=args.copy_ext_dict, copy_src_words=copy_src_words
            )
        )
        if num_workers > 1:
            pool.join()
            # The success callback never fires for a failed task, so surface
            # worker exceptions here instead of merging incomplete shards.
            for async_result in async_results:
                async_result.get()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        # An input without tokens would otherwise raise ZeroDivisionError here.
        n_tokens = n_seq_tok[1]
        replaced_pct = 100 * sum(replaced.values()) / n_tokens if n_tokens else 0.0
        copied_pct = 100 * sum(copied.values()) / n_tokens if n_tokens else 0.0
        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}, {:.3}% <unk> copied from src".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_tokens,
                replaced_pct,
                vocab.unk_word,
                copied_pct
            )
        )

        return words_list

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1, copy_src_words=None):
        """Write one split; return its word lists for ``"binary"`` output, else None."""
        if args.output_format == "binary":
            return make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers, copy_src_words)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

            return None

    def make_all(lang, vocab, source_words_list_dict=None):
        """Write all splits for ``lang`` and return their word lists keyed by output prefix.

        ``source_words_list_dict`` is the mapping returned for the source
        language; the entry with the same output prefix (``"train"``,
        ``"valid"``, ``"valid1"``, ...) is passed on as ``copy_src_words``.
        """
        # A fresh mapping per call instead of a shared mutable default argument.
        if source_words_list_dict is None:
            source_words_list_dict = {}
        words_list_dict = {}

        if args.trainpref:
            words_list_dict["train"] = make_dataset(
                vocab, args.trainpref, "train", lang, num_workers=args.workers,
                copy_src_words=source_words_list_dict.get("train"))
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                # Key by output prefix so that every extra split is paired
                # with the words of the same split on the source side.
                words_list_dict[outprefix] = make_dataset(
                    vocab, validpref, outprefix, lang,
                    copy_src_words=source_words_list_dict.get(outprefix))
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                words_list_dict[outprefix] = make_dataset(
                    vocab, testpref, outprefix, lang,
                    copy_src_words=source_words_list_dict.get(outprefix))

        return words_list_dict

    source_words_list_dict = make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict, source_words_list_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        from fairseq.tokenizer import tokenize_line
        import numpy as np
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        src_labels_list = []
        tgt_labels_list = []
        # The three files are read in lockstep, one sentence pair per line.
        with open(args.alignfile, "r", encoding='utf-8') as align_file, \
                open(src_file_name, "r", encoding='utf-8') as src_file, \
                open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
            for a, s, t in zip_longest(align_file, src_file, tgt_file):
                src_words = tokenize_line(s)
                tgt_words = tokenize_line(t)
                # each alignment item has the form "<src index>-<tgt index>"
                ai = [tuple(pair.split("-")) for pair in a.split()]
                # every token starts with label 1
                src_labels = np.ones(len(src_words), int)
                tgt_labels = np.ones(len(tgt_words), int)
                for sai, tai in ai:
                    src_pos, tgt_pos = int(sai), int(tai)
                    src_out_of_range = src_pos >= len(src_words)
                    tgt_out_of_range = tgt_pos >= len(tgt_words)
                    if src_out_of_range or tgt_out_of_range:
                        # Skip alignment points outside either sentence
                        # instead of failing with IndexError.
                        print('Bad case:')
                        if src_out_of_range:
                            print(src_words)
                        if tgt_out_of_range:
                            print(tgt_words)
                        print(ai)
                        continue
                    # label 0 marks aligned tokens that are the identical word
                    if src_words[src_pos] == tgt_words[tgt_pos]:
                        src_labels[src_pos] = 0
                        tgt_labels[tgt_pos] = 0
                src_labels_list.append(src_labels)
                tgt_labels_list.append(tgt_labels)

        save_label_file(os.path.join(args.destdir, "train.label.{}.txt".format(args.source_lang)), src_labels_list)
        save_label_file(os.path.join(args.destdir, "train.label.{}.txt".format(args.target_lang)), tgt_labels_list)
Ejemplo n.º 11
0
def main(args):
    """Build the dictionaries and write every dataset split into ``args.destdir``.

    The source and target dictionaries are loaded from ``--srcdict`` /
    ``--tgtdict`` or built from the training files and then saved. After
    that each train/valid/test split is written for the source language and,
    unless ``--only-source`` is set, for the target language: copied as text
    when ``args.dataset_impl`` is ``"raw"`` and binarized otherwise.

    Args:
        args: parsed command-line namespace.

    Raises:
        FileExistsError: if a dictionary would be built but its output file
            already exists.
        AssertionError: if both ``--srcdict`` and ``--tgtdict`` are combined
            with ``--joined-dictionary`` or ``--trainpref`` is missing when a
            dictionary has to be built.
        Exception: any exception raised inside a binarization worker is
            re-raised in the parent process.
    """
    utils.import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Return the training file path for ``lang`` (bare prefix if ``lang`` is falsy)."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Return ``prefix`` with ``.lang`` appended unless ``lang`` is None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` inside ``args.destdir``."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Return the dictionary file path for ``lang`` inside ``args.destdir``."""
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary from ``filenames``; exactly one of ``src``/``tgt`` must be true."""
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    # Refuse to overwrite a dictionary file that is about to be rebuilt.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        # one dictionary object is shared by both sides
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        """Binarize ``input_prefix.lang`` with ``vocab`` and write it under ``output_prefix``.

        The file is cut into ``num_workers`` chunks: the first chunk is
        binarized in this process, the others in a worker pool, and the
        worker shards are merged into the main dataset and then deleted.
        """
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            # accumulate the statistics reported for one chunk
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        async_results = []
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                async_results.append(
                    pool.apply_async(binarize,
                                     (args, input_file, vocab, prefix, lang,
                                      offsets[worker_id],
                                      offsets[worker_id + 1]),
                                     callback=merge_result))
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, lang, "bin"),
                                          impl=args.dataset_impl,
                                          vocab_size=len(vocab))
        merge_result(
            Binarizer.binarize(input_file,
                               vocab,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            # The success callback never fires for a failed task, so surface
            # worker exceptions here instead of merging incomplete shards.
            for async_result in async_results:
                async_result.get()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        # An input without tokens would otherwise raise ZeroDivisionError here.
        n_tokens = n_seq_tok[1]
        replaced_pct = (100 * sum(replaced.values()) / n_tokens
                        if n_tokens else 0.0)
        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_tokens,
            replaced_pct,
            vocab.unk_word,
        ))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Write one split: copy the text file for ``"raw"``, binarize it otherwise."""
        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)

    def make_all(lang, vocab):
        """Write the train split and every comma-separated valid/test split for ``lang``."""
        if args.trainpref:
            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                # the first split keeps the bare name, later ones get an index
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))
Ejemplo n.º 12
0
def main(args):
    """Preprocess article text into RoBERTa-tokenized indexed datasets.

    Builds or loads the source/target dictionaries and saves them under
    ``args.destdir``. For each configured train/valid prefix it then writes
    two indexed datasets: fixed-size article segments (stored under the
    source language name) and, for each segment, the tokens that follow it
    (stored under the ``article_next`` name). When ``args.alignfile`` is set
    an alignment dictionary is written as well.

    Only the source language and the train/valid splits are processed by
    this variant.

    Args:
        args: parsed command-line namespace; the attributes read here
            include ``destdir``, ``task``, ``trainpref``, ``validpref``,
            ``srcdict``, ``tgtdict``, ``source_lang``, ``target_lang``,
            ``joined_dictionary``, ``only_source``, ``output_format``,
            ``workers`` and ``alignfile``.

    Raises:
        FileExistsError: if a dictionary file already exists in
            ``args.destdir`` and no explicit dictionary was supplied for
            that side.
    """
    from fairseq import utils
    utils.xpprint(args)

    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Return the training file path for ``lang``."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Return ``prefix`` with ``.lang`` appended unless ``lang`` is None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` in the output dir."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Return the path the dictionary for ``lang`` is saved to."""
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary through the task.

        Exactly one of ``src``/``tgt`` must be true; it selects which
        threshold and size limits from ``args`` are applied.
        """
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    # Refuse to overwrite a dictionary that is already in the output
    # directory unless the caller passed one in explicitly.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = roberta_dictionary.RobertaDictionary.load_json(
                args.srcdict)
            print('load bert dict from {} | size {}'.format(
                args.srcdict, len(src_dict)))
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = roberta_dictionary.RobertaDictionary.load_json(
                    args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        """Tokenize one input file and write segment/continuation datasets.

        Each input line is one article whose sentences are separated by
        ``<S_SEP>``. The article is cut into segments of at most 512 tokens;
        for every segment the next (up to) 256 tokens are stored as its
        continuation. Segmentation of an article stops as soon as a
        continuation would be shorter than 30% of 256 tokens.

        ``num_workers`` is accepted for signature compatibility but is not
        used: the file is processed in this process only.
        """
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        print('input_prefix', input_prefix)
        print(dict_path(lang))

        # The loaded dictionary is not used below; the call is kept so that
        # a saved dictionary that cannot be loaded still surfaces here.
        roberta_dictionary.RobertaDictionary.load(dict_path(lang))
        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        from pytorch_transformers import RobertaTokenizer
        import torch

        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

        # Penn Treebank escapes (-LRB- -RRB- -LSB- -RSB- -LCB- -RCB- and the
        # two quote forms) mapped back to the original characters.
        penn2orig = {
            "``": '"',
            "''": '"',
            "-LRB-": '(',
            "-RRB-": ')',
            "-LSB-": '[',
            "-RSB-": ']',
            "-LCB-": '{',
            "-RCB-": '}'
        }

        def penn_token2orig_token(sent):
            """Replace Penn Treebank escape tokens in ``sent``."""
            return ' '.join(
                penn2orig.get(wd, wd) for wd in sent.strip().split())

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        output_ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, 'article_next', "bin"))
        truncated_number = 512
        output_length = 256
        min_output_length = 0.3 * output_length

        CLS_TOKEN = '<s>'
        SEP_TOKEN = '</s>'

        # The context manager closes the input file even if tokenization or
        # one of the sanity checks below raises.
        with open(input_file, encoding='utf8') as article_file:
            for line in article_file:
                sents = [
                    tokenizer.tokenize(penn_token2orig_token(sent))
                    for sent in line.strip().split('<S_SEP>')
                ]
                # Flatten the sentences into one token stream with a
                # separator between consecutive sentences.
                article_toks = []
                for i, sent in enumerate(sents):
                    if i != 0:
                        article_toks.append(SEP_TOKEN)
                    article_toks.extend(sent)

                article_segments = []
                output_segments = []
                tmp_seg = []
                for i, tok in enumerate(article_toks):
                    if len(tmp_seg) == 0:
                        tmp_seg.append(CLS_TOKEN)
                    tmp_seg.append(tok)
                    if tok == SEP_TOKEN:
                        # Inside a segment a sentence boundary is written as
                        # two consecutive separator tokens.
                        tmp_seg.append(tok)
                    if len(tmp_seg) >= truncated_number:
                        # Cap the segment and make it end on a separator.
                        tmp_seg = tmp_seg[:truncated_number]
                        tmp_seg[-1] = SEP_TOKEN
                        # Slicing already clamps at the end of the article.
                        tmp_output = article_toks[i + 1:i + 1 + output_length]
                        if len(tmp_output) < min_output_length:
                            # Not enough continuation left: the rest of this
                            # article is dropped.
                            break
                        article_segments.append(
                            tokenizer.convert_tokens_to_ids(tmp_seg))
                        output_segments.append(
                            tokenizer.convert_tokens_to_ids(tmp_output))
                        tmp_seg = []
                # A trailing segment that never reached the size limit is
                # not written.
                assert len(article_segments) == len(output_segments)
                for seg_ids, out_ids in zip(article_segments,
                                            output_segments):
                    assert len(seg_ids) <= truncated_number
                    assert min_output_length <= len(out_ids) <= output_length
                    ds.add_item(torch.IntTensor(seg_ids))
                    output_ds.add_item(torch.IntTensor(out_ids))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
        output_ds.finalize(
            dataset_dest_file(args, output_prefix, 'article_next', "idx"))
        print('done!')

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Write one split as binary datasets or as a raw text copy.

        Any ``args.output_format`` other than "binary" or "raw" writes
        nothing.
        """
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang, vocab):
        """Process the train and valid splits of ``lang``.

        Test prefixes are not processed by this variant.
        """
        if args.trainpref:
            print(args.trainpref, lang)
            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            # Comma-separated prefixes: the first is "valid", later ones get
            # their position appended.
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    # Only the source side is binarized; the continuation dataset is derived
    # from the same input inside make_binary_dataset.
    make_all(args.source_lang, src_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        # freq_map[src index][tgt index] -> number of times the pair is aligned
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        # Each alignment entry has the form "<src pos>-<tgt pos>".
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk(
                            ) and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        # Keep, for every source index, its most frequently aligned target.
        align_dict = {
            srcidx: max(tgt_counts, key=tgt_counts.get)
            for srcidx, tgt_counts in freq_map.items()
        }

        with open(os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang,
                                             args.target_lang),
        ),
                  "w",
                  encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def main(args):
    """Preprocess article text into BERT-tokenized indexed datasets.

    Builds or loads the source/target dictionaries and saves them under
    ``args.destdir``. For each configured train/valid/test prefix it then
    writes two indexed datasets: article segments (stored under the source
    language name) and, for each segment, the tokens that follow it (stored
    under the ``article_next`` name). When ``args.alignfile`` is set an
    alignment dictionary is written as well.

    Only the source language is binarized by this variant.

    Args:
        args: parsed command-line namespace; the attributes read here
            include ``destdir``, ``task``, ``trainpref``, ``validpref``,
            ``testpref``, ``srcdict``, ``tgtdict``, ``source_lang``,
            ``target_lang``, ``joined_dictionary``, ``only_source``,
            ``output_format``, ``workers`` and ``alignfile``.

    Raises:
        FileExistsError: if a dictionary file already exists in
            ``args.destdir`` and no explicit dictionary was supplied for
            that side.
    """
    from fairseq import utils
    utils.xpprint(args)

    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Return the training file path for ``lang``."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Return ``prefix`` with ``.lang`` appended unless ``lang`` is None."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` in the output dir."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Return the path the dictionary for ``lang`` is saved to."""
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary through the task.

        Exactly one of ``src``/``tgt`` must be true; it selects which
        threshold and size limits from ``args`` are applied.
        """
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    # Refuse to overwrite a dictionary that is already in the output
    # directory unless the caller passed one in explicitly.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = bert_dictionary.BertDictionary.load(args.srcdict)
        elif args.tgtdict:
            # Only --tgtdict was given, so the shared dictionary comes from
            # it (args.srcdict is empty in this branch).
            src_dict = bert_dictionary.BertDictionary.load(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = bert_dictionary.BertDictionary.load(args.srcdict)
            print('load bert dict from {} | size {}'.format(
                args.srcdict, len(src_dict)))
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = bert_dictionary.BertDictionary.load(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        """Tokenize one input file and write segment/continuation datasets.

        Each input line is one article whose sentences are separated by
        ``<S_SEP>``. The article's token ids are cut into consecutive
        segments of at most 511 ids; for every segment the next (up to) 256
        ids are stored as its continuation. Segments whose continuation is
        shorter than 30% of 256 ids are skipped.

        ``num_workers`` is accepted for signature compatibility but is not
        used: the file is processed in this process only.
        """
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        print('input_prefix', input_prefix)
        print(dict_path(lang))

        saved_dict = bert_dictionary.BertDictionary.load(dict_path(lang))
        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        from pytorch_transformers import BertTokenizer
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Penn Treebank escapes (-LRB- -RRB- -LSB- -RSB- -LCB- -RCB- and the
        # two quote forms) mapped back to the original characters.
        penn2orig = {
            "``": '"',
            "''": '"',
            "-LRB-": '(',
            "-RRB-": ')',
            "-LSB-": '[',
            "-RSB-": ']',
            "-LCB-": '{',
            "-RCB-": '}'
        }

        def penn_token2orig_token(sent):
            """Replace Penn Treebank escape tokens in ``sent``."""
            return ' '.join(
                penn2orig.get(wd, wd) for wd in sent.strip().split())

        num_token, num_unk_token = 0, 0
        num_seq = 0
        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        output_ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, 'article_next', "bin"))
        # 511 leaves room for the CLS id that is prepended to each segment.
        article_input = 511
        article_next = 256
        BERT_CLS_ID = tokenizer.convert_tokens_to_ids([BERT_CLS])[0]
        BERT_SEP_ID = tokenizer.convert_tokens_to_ids([BERT_SEP])[0]

        # The context manager closes the input file even if tokenization
        # raises part-way through.
        with open(input_file, encoding='utf8') as article_file:
            for line in article_file:
                sents = [
                    tokenizer.tokenize(penn_token2orig_token(sent))
                    for sent in line.strip().split('<S_SEP>')
                ]
                article_wids = []
                for i, sent in enumerate(sents):
                    if i != 0:
                        article_wids.append(saved_dict.sep_index)
                    if len(sent) > article_input:
                        # Over-long sentences are converted in chunks of at
                        # most ``article_input`` tokens.
                        wids = []
                        for x in range(0, len(sent), article_input):
                            wids.extend(
                                tokenizer.convert_tokens_to_ids(
                                    sent[x:x + article_input]))
                    else:
                        wids = tokenizer.convert_tokens_to_ids(sent)
                    article_wids.extend(wids)
                    num_unk_token += sum(
                        1 for wid in wids if wid == saved_dict.unk_index)
                    num_token += len(wids)

                article_segments = [
                    article_wids[x:x + article_input]
                    for x in range(0, len(article_wids), article_input)
                ]

                cur_position = 0
                for article_seq in article_segments:
                    cur_position += len(article_seq)
                    # Slicing already clamps at the end of the article.
                    output_seg = article_wids[cur_position:cur_position +
                                              article_next]
                    if len(output_seg) < 0.3 * article_next:
                        continue
                    num_seq += 1
                    if len(article_seq) > article_input:
                        print('lang: %s, token len: %d, truncated len: %d' %
                              (lang, len(article_seq), article_input))
                    if lang == 'article':
                        # End the segment on a separator unless one of its
                        # last two ids already is one, then prepend CLS.
                        if article_seq[-1] != BERT_SEP_ID:
                            if article_seq[-2] != BERT_SEP_ID:
                                article_seq[-1] = BERT_SEP_ID
                        article_seq = [BERT_CLS_ID] + article_seq

                    if len(output_seg) > article_next:
                        print(
                            'lang: article_next, token len: %d, truncated len: %d'
                            % (len(output_seg), article_next))

                    ds.add_item(torch.IntTensor(article_seq))
                    output_ds.add_item(torch.IntTensor(output_seg))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
        output_ds.finalize(
            dataset_dest_file(args, output_prefix, 'article_next', "idx"))
        # An empty input yields zero tokens; report 0% rather than dividing
        # by zero after all the output has already been written.
        unk_percent = 100 * num_unk_token / num_token if num_token else 0.0
        print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
            lang, input_file, num_seq, num_token, unk_percent,
            saved_dict.unk_word
            if hasattr(saved_dict, 'unk_word') else '<no_unk_word>'))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Write one split as binary datasets or as a raw text copy.

        Any ``args.output_format`` other than "binary" or "raw" writes
        nothing.
        """
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang, vocab):
        """Process the train, valid and test splits of ``lang``.

        Valid/test prefixes are comma-separated: the first is written under
        the bare split name, later ones get their position appended.
        """
        if args.trainpref:
            print(args.trainpref, lang)
            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    # Only the source side is binarized; the continuation dataset is derived
    # from the same input inside make_binary_dataset.
    make_all(args.source_lang, src_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        # freq_map[src index][tgt index] -> number of times the pair is aligned
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        # Each alignment entry has the form "<src pos>-<tgt pos>".
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk(
                            ) and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        # Keep, for every source index, its most frequently aligned target.
        align_dict = {
            srcidx: max(tgt_counts, key=tgt_counts.get)
            for srcidx, tgt_counts in freq_map.items()
        }

        with open(os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang,
                                             args.target_lang),
        ),
                  "w",
                  encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
Ejemplo n.º 14
0
def main(args):
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    logger.addHandler(
        logging.FileHandler(
            filename=os.path.join(args.destdir, "preprocess.log")))
    logger.info(args)

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Return the training file path for ``lang``.

        A falsy ``lang`` yields the bare ``args.trainpref``.
        """
        suffix = ("." + lang) if lang else ""
        return "{}{}".format(args.trainpref, suffix)

    def file_name(prefix, lang):
        """Return ``prefix`` with ``.lang`` appended unless ``lang`` is None."""
        if lang is None:
            return prefix
        return prefix + ".{lang}".format(lang=lang)

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` in ``args.destdir``."""
        out_dir = args.destdir
        leaf = file_name(prefix, lang)
        return os.path.join(out_dir, leaf)

    def dict_path(lang):
        """Return the destination path of the dictionary file for ``lang``."""
        stem = dest_path("dict", lang)
        return stem + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary from ``filenames`` through the task.

        Exactly one of ``src``/``tgt`` must be true; it selects whether the
        source-side or target-side threshold and size limit from ``args``
        are passed on.
        """
        assert src ^ tgt
        if src:
            min_count, max_words = args.thresholdsrc, args.nwordssrc
        else:
            min_count, max_words = args.thresholdtgt, args.nwordstgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=min_count,
            nwords=max_words,
            padding_factor=args.padding_factor,
        )

    label_dictionary, label_schema = task.load_label_dictionary(
        args, args.label_schema)
    labelled_span_parser = make_parse_labelled_spans(label_dictionary,
                                                     label_schema)

    def make_binary_labelled_spans_dataset(input_prefix, output_prefix,
                                           num_workers):
        """Binarize a labelled-spans file into one indexed dataset.

        ``input_prefix`` is used as the input file path as-is (no language
        suffix is appended). The file is split into ``num_workers`` byte
        ranges: this process handles the first range and pool workers handle
        the rest, after which the workers' temporary datasets are merged
        into the final one and deleted.

        Args:
            input_prefix: path of the labelled-spans file to read.
            output_prefix: name prefix of the dataset written to the
                destination directory.
            num_workers: number of chunks / processes to use.
        """
        # One-element list so the nested callback can update the count.
        nseq = [0]

        def merge_result(worker_result):
            # Accumulate the sentence count reported by one chunk.
            nseq[0] += worker_result["nseq"]

        input_file = input_prefix
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # Chunk 0 is processed in this process, so only num_workers - 1
            # pool processes are needed.
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                # Each worker writes under its own "<output_prefix><id>" name.
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_labelled_spans,
                    (
                        args,
                        input_file,
                        labelled_span_parser,
                        prefix,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, None, "bin"),
                                          impl=args.dataset_impl)

        # Process the first chunk here while the pool works on the others.
        merge_result(
            Binarizer.binarize_alignments(
                input_file,
                labelled_span_parser,
                lambda t: ds.add_item(t),
                offset=0,
                end=offsets[1],
            ))
        if num_workers > 1:
            pool.join()
            # Append the workers' temporary datasets in worker order, then
            # delete their data and index files.
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                # NOTE(review): presumably matches where
                # binarize_labelled_spans writes its shard — confirm against
                # its definition.
                temp_file_path = "{}/{}".format(args.destdir, prefix)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info("[labelled spans] {}: parsed {} sentences".format(
            input_file, nseq[0]))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        """Binarize the text file for ``lang`` into one indexed dataset.

        The input file (``input_prefix`` plus ``.lang`` when ``lang`` is not
        None) is split into ``num_workers`` byte ranges: this process handles
        the first range and pool workers handle the rest, after which the
        workers' temporary datasets are merged into the final one and
        deleted. A summary line with sentence, token and replaced-token
        counts is logged at the end.

        Args:
            vocab: dictionary used to encode tokens; its ``unk_word`` is
                reported in the summary.
            input_prefix: input path without the language suffix.
            output_prefix: name prefix of the dataset that is written.
            lang: language suffix, or None for no suffix.
            num_workers: number of chunks / processes to use.
        """
        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        # [sentence count, token count]; a list so the callback can mutate it.
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            # Fold one chunk's statistics into the running totals.
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # Chunk 0 is processed in this process, so only num_workers - 1
            # pool processes are needed.
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab),
        )
        # Process the first chunk here while the pool works on the others.
        merge_result(
            Binarizer.binarize(input_file,
                               vocab,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            # Append the workers' temporary datasets in worker order, then
            # delete their data and index files.
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        # An empty input yields zero tokens; report 0% rather than raising
        # ZeroDivisionError after the dataset has already been written.
        total_tokens = n_seq_tok[1]
        if total_tokens:
            replaced_percent = 100 * sum(replaced.values()) / total_tokens
        else:
            replaced_percent = 0.0
        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                total_tokens,
                replaced_percent,
                vocab.unk_word,
            ))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Write one split for ``lang``: copy the raw text or binarize it.

        When ``args.dataset_impl`` is ``"raw"`` the input text file is copied
        unchanged into the destination directory; every other implementation
        is handed to ``make_binary_dataset``.
        """
        if args.dataset_impl != "raw":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)
            return

        # Raw mode: the copied file name carries the "<source>-<target>" tag.
        pair_tag = ".{}-{}".format(args.source_lang, args.target_lang)
        copy_target = dest_path(output_prefix + pair_tag, lang)
        shutil.copyfile(file_name(input_prefix, lang), copy_target)

    if args.nonterm_suffix:
        if args.trainpref and os.path.exists("{}.{}".format(
                args.trainpref, args.nonterm_suffix)):
            make_binary_labelled_spans_dataset(
                "{}.{}".format(args.trainpref, args.nonterm_suffix),
                "train.nonterm",
                args.workers,
            )
        if args.validpref and os.path.exists("{}.{}".format(
                args.validpref, args.nonterm_suffix)):
            make_binary_labelled_spans_dataset(
                "{}.{}".format(args.validpref, args.nonterm_suffix),
                "valid.nonterm",
                args.workers,
            )
        if args.testpref and os.path.exists("{}.{}".format(
                args.testpref, args.nonterm_suffix)):
            make_binary_labelled_spans_dataset(
                "{}.{}".format(args.testpref, args.nonterm_suffix),
                "test.nonterm",
                args.workers,
            )
    elif args.term_suffix:
        if args.trainpref:
            make_dataset(
                label_dictionary,
                args.trainpref + "." + args.term_suffix,
                "train.term",
                args.source_lang,
                num_workers=args.workers,
            )
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid.term{}".format(k) if k > 0 else "valid.term"
                make_dataset(
                    label_dictionary,
                    validpref + "." + args.term_suffix,
                    outprefix,
                    args.source_lang,
                    num_workers=args.workers,
                )
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test.term{}".format(k) if k > 0 else "test.term"
                make_dataset(
                    label_dictionary,
                    testpref + "." + args.term_suffix,
                    outprefix,
                    args.source_lang,
                    num_workers=args.workers,
                )