Example #1
def create_tokenizer(tokenizer_type, model_path, vocab_path):
    if tokenizer_type == 'whitespace':
        return tokenizers.create(tokenizer_type, vocab=Vocab.load(vocab_path))
    elif tokenizer_type == 'spm':
        return tokenizers.create(tokenizer_type,
                                 model_path=model_path,
                                 vocab=vocab_path)
    elif tokenizer_type == 'subword_nmt':
        return tokenizers.create(tokenizer_type,
                                 codec_path=model_path,
                                 vocab_path=vocab_path)
    elif tokenizer_type == 'yttm':
        return tokenizers.create(tokenizer_type, model_path=model_path)
    elif tokenizer_type == 'hf_bytebpe':
        return tokenizers.create(tokenizer_type,
                                 merges_file=model_path,
                                 vocab_file=vocab_path)
    elif tokenizer_type == 'hf_wordpiece':
        return tokenizers.create(tokenizer_type, vocab_file=vocab_path)
    elif tokenizer_type == 'hf_bpe':
        return tokenizers.create(tokenizer_type,
                                 merges_file=model_path,
                                 vocab_file=vocab_path)
    else:
        raise NotImplementedError
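A hedged usage sketch of this factory follows; the file names are hypothetical and the encode() call assumes the tokenizer interface exercised by the tests further below.

# Hypothetical paths; create_tokenizer only dispatches on tokenizer_type.
spm_tokenizer = create_tokenizer('spm',
                                 model_path='spm.model',
                                 vocab_path='spm.vocab')
ws_tokenizer = create_tokenizer('whitespace',
                                model_path=None,
                                vocab_path='corpus.vocab')
tokens = spm_tokenizer.encode('Hello world!')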
Example #2
def create_tokenizer(tokenizer_type, model_path, vocab_path):
    if tokenizer_type == 'whitespace':
        return tokenizers.create(tokenizer_type, vocab=Vocab.load(vocab_path))
    elif tokenizer_type == 'spm':
        return tokenizers.create(tokenizer_type,
                                 model_path=model_path,
                                 vocab=vocab_path)
    elif tokenizer_type == 'subword_nmt':
        return tokenizers.create(tokenizer_type,
                                 model_path=model_path,
                                 vocab=vocab_path)
    elif tokenizer_type == 'yttm':
        return tokenizers.create(tokenizer_type, model_path=model_path)
    elif tokenizer_type in ['hf_bytebpe', 'hf_wordpiece', 'hf_bpe']:
        if huggingface.is_new_version_model_file(model_path):
            return tokenizers.create('hf_tokenizer',
                                     model_path=model_path,
                                     vocab=vocab_path)
        elif tokenizer_type == 'hf_bytebpe':
            return tokenizers.create(tokenizer_type,
                                     merges_file=model_path,
                                     vocab_file=vocab_path)
        elif tokenizer_type == 'hf_wordpiece':
            return tokenizers.create(tokenizer_type, vocab_file=vocab_path)
        elif tokenizer_type == 'hf_bpe':
            return tokenizers.create(tokenizer_type,
                                     merges_file=model_path,
                                     vocab_file=vocab_path)
    else:
        raise NotImplementedError
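huggingface.is_new_version_model_file is not shown in these excerpts; a hypothetical sketch of such a check, relying only on the fact (used in Example #8 below) that tokenizers >= 0.8 save the whole tokenizer as a single JSON file:

import json

def is_new_version_model_file(model_path):
    # Hypothetical check: new-style HuggingFace tokenizer files are single
    # JSON documents, while old-style merges/vocab files are plain text.
    try:
        with open(model_path, 'r', encoding='utf-8') as f:
            json.load(f)
        return True
    except (ValueError, UnicodeDecodeError):
        return False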
Example #3
def test_spacy_tokenizer():
    en_tokenizer = SpacyTokenizer('en')
    de_tokenizer = SpacyTokenizer('de')
    gt_en_tokenized = [['Four', 'score', 'and', 'seven', 'years', 'ago', 'our', 'fathers',
                        'brought', 'forth', 'on', 'this', 'continent', ',', 'a', 'new', 'nation',
                        ',', 'conceived', 'in', 'Liberty', ',', 'and', 'dedicated', 'to', 'the',
                        'proposition', 'that', 'all', 'men', 'are', 'created', 'equal', '.'],
                       ['In', 'spite', 'of', 'the', 'debate', 'going', 'on', 'for', 'months',
                        'about', 'the', 'photos', 'of', 'Özil', 'with', 'the', 'Turkish',
                        'President', 'Recep', 'Tayyip', 'Erdogan', ',', 'he', 'regrets', 'the',
                        'return', 'of', 'the', '92-match', 'national', 'player', 'Özil', '.']]
    gt_de_tokenized = [['Goethe', 'stammte', 'aus', 'einer', 'angesehenen', 'bürgerlichen',
                        'Familie', ';', 'sein', 'Großvater', 'mütterlicherseits', 'war', 'als',
                        'Stadtschultheiß', 'höchster', 'Justizbeamter', 'der', 'Stadt', 'Frankfurt',
                        ',', 'sein', 'Vater', 'Doktor', 'der', 'Rechte', 'und', 'kaiserlicher',
                        'Rat', '.'],
                       ['"', 'Das', 'ist', 'eine', 'Frage', ',', 'die', 'natürlich', 'davon',
                        'abhängt', ',', 'dass', 'man', 'einmal', 'ins', 'Gespräch', 'kommt', ',',
                        'dass', 'man', 'mit', 'ihm', 'auch', 'darüber', 'spricht', ',', 'warum',
                        'er', 'das', 'eine', 'oder', 'andere', 'offenbar', 'so', 'empfunden', 'hat',
                        ',', 'wie', 'das', 'in', 'seinem', 'Statement', 'niedergelegt', 'ist', '"',
                        ',', 'sagte', 'Grindel', 'im', 'Fußball-Podcast', '"', 'Phrasenmäher', '"',
                        'der', '"', 'Bild-Zeitung', '.']]
    verify_encode_token(en_tokenizer, EN_SAMPLES, gt_en_tokenized)
    verify_encode_token(de_tokenizer, DE_SAMPLES, gt_de_tokenized)
    vocab = Vocab(collections.Counter(sum(gt_en_tokenized + gt_de_tokenized, [])))
    en_tokenizer.set_vocab(vocab)
    de_tokenizer.set_vocab(vocab)
    verify_pickleble(en_tokenizer, SpacyTokenizer)
    verify_pickleble(de_tokenizer, SpacyTokenizer)
    verify_encode_token_with_offsets(en_tokenizer, EN_SAMPLES)
    verify_encode_token_with_offsets(de_tokenizer, DE_SAMPLES)
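The verify_encode_token helper is not included in these excerpts; a minimal hypothetical sketch of the check it performs, assuming an encode(sentence, output_type) signature on the tokenizer:

def verify_encode_token(tokenizer, samples, expected_tokens):
    # Hypothetical re-implementation: encode each sample to string tokens
    # and compare against the ground-truth token lists.
    for sample, expected in zip(samples, expected_tokens):
        assert tokenizer.encode(sample, str) == expected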
Example #4
def test_load_embeddings():
    text_data = [
        'hello', 'world', 'hello', 'nice', 'world', 'hi', 'world', 'sadgood'
    ]
    counter = collections.Counter(text_data)
    vocab1 = Vocab(counter)
    # load with vocab
    matrix1 = load_embeddings(vocab1)
    assert len(matrix1) == len(vocab1)
    # load without vocab
    matrix2, vocab2 = load_embeddings()
    assert len(matrix2) == len(vocab2)
    np.testing.assert_almost_equal(matrix1[vocab1["hello"]],
                                   matrix2[vocab2["hello"]])

    # test_unk_method
    def simple(words):
        return np.ones((len(words), 50))

    matrix3 = load_embeddings(vocab1, unk_method=simple)
    assert sum(matrix3[vocab1['sadgood']] == 1) == matrix3.shape[-1]
    np.testing.assert_almost_equal(matrix3[vocab1["hello"]],
                                   matrix2[vocab2["hello"]])

    # load txt
    with tempfile.TemporaryDirectory() as root:
        path = os.path.join(root, "tmp.txt")
        with open(path, "w") as f:
            f.write("{} {}\n".format(matrix1.shape[0], matrix1.shape[1]))
            for word, vec in zip(vocab1.all_tokens, matrix1):
                f.write(word + " ")
                f.write(" ".join([str(num) for num in vec.tolist()]))
                f.write("\n")
        matrix4 = load_embeddings(vocab1, path)
        np.testing.assert_almost_equal(matrix4, matrix1)
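As a variation on the simple unk_method above, a hedged sketch of a random initializer for out-of-vocabulary words; the name and defaults are hypothetical, only the (words) -> array contract comes from the test:

import numpy as np

def random_unk(words, dim=50, scale=0.1):
    # Hypothetical unk_method: one small random vector per unknown word
    # (dim matches the 50-dimensional vectors returned by `simple` above).
    rng = np.random.default_rng(seed=0)
    return rng.normal(scale=scale, size=(len(words), dim))

# matrix5 = load_embeddings(vocab1, unk_method=random_unk)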
Example #5
def test_get_fasttext_model():
    text_data = ['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world']
    counter = collections.Counter(text_data)
    vocab1 = Vocab(counter)
    matrix1 = load_embeddings(vocab1, 'wiki.en')
    ft = get_fasttext_model('wiki.en')
    np.testing.assert_almost_equal(matrix1[vocab1["hello"]], ft['hello'], decimal=4)
    with pytest.raises(ValueError):
        get_fasttext_model('wiki.multi.ar')
Example #6
def test_jieba_tokenizer():
    tokenizer = JiebaTokenizer()
    gt_zh_tokenized = [['苟活', '者', '在', '淡红', '的', '血色', '中', ',',
                        '会', '依稀', '看见', '微茫', '的', '希望', ';', '真的',
                        '猛士', ',', '将', '更奋', '然而', '前行', '。'],
                       ['参加', '工作', ',', '哈尔滨工业大学', '无线电', '工程系', '电子仪器',
                        '及', '测量', '技术', '专业', '毕业', '。']]
    verify_encode_token(tokenizer, ZH_SAMPLES, gt_zh_tokenized)
    verify_decode(tokenizer, ZH_SAMPLES, str)
    vocab = Vocab(collections.Counter(sum(gt_zh_tokenized, [])))
    verify_decode_no_vocab_raise(tokenizer)
    tokenizer.set_vocab(vocab)
    verify_decode(tokenizer, ZH_SAMPLES, int)
    verify_pickleble(tokenizer, JiebaTokenizer)
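verify_decode_no_vocab_raise is another helper that is not shown; a hypothetical sketch of the behaviour it asserts, namely that decoding token ids without a vocabulary attached should raise:

import pytest

def verify_decode_no_vocab_raise(tokenizer):
    # Hypothetical re-implementation: decoding integer ids needs a vocab,
    # so a tokenizer without one is expected to raise.
    with pytest.raises(Exception):
        tokenizer.decode([0, 1, 2])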
Example #7
def test_whitespace_tokenizer():
    tokenizer = WhitespaceTokenizer()
    gt_en_tokenized = [['Four', 'score', 'and', 'seven', 'years', 'ago', 'our', 'fathers', 'brought',
                        'forth', 'on', 'this', 'continent,', 'a', 'new', 'nation,', 'conceived',
                        'in', 'Liberty,', 'and', 'dedicated', 'to', 'the', 'proposition', 'that',
                        'all', 'men', 'are', 'created', 'equal.'],
                       ['In', 'spite', 'of', 'the', 'debate', 'going', 'on', 'for', 'months',
                        'about', 'the', 'photos', 'of', 'Özil', 'with', 'the', 'Turkish',
                        'President', 'Recep', 'Tayyip', 'Erdogan,', 'he', 'regrets', 'the',
                        'return', 'of', 'the', '92-match', 'national', 'player', 'Özil.']]
    gt_de_tokenized = [['Goethe', 'stammte', 'aus', 'einer', 'angesehenen', 'bürgerlichen',
                        'Familie;', 'sein', 'Großvater', 'mütterlicherseits', 'war', 'als',
                        'Stadtschultheiß', 'höchster', 'Justizbeamter', 'der', 'Stadt',
                        'Frankfurt,', 'sein', 'Vater', 'Doktor', 'der', 'Rechte', 'und',
                        'kaiserlicher', 'Rat.'],
                       ['"Das', 'ist', 'eine', 'Frage,', 'die', 'natürlich', 'davon', 'abhängt,',
                        'dass', 'man', 'einmal', 'ins', 'Gespräch', 'kommt,', 'dass', 'man', 'mit',
                        'ihm', 'auch', 'darüber', 'spricht,', 'warum', 'er', 'das', 'eine', 'oder',
                        'andere', 'offenbar', 'so', 'empfunden', 'hat,', 'wie', 'das', 'in',
                        'seinem', 'Statement', 'niedergelegt', 'ist",', 'sagte', 'Grindel', 'im',
                        'Fußball-Podcast', '"Phrasenmäher"', 'der', '"Bild-Zeitung.']]
    for _ in range(2):
        # Inject noise and test for encode
        noisy_en_samples = [random_inject_space(ele) for ele in EN_SAMPLES]
        noisy_de_samples = [random_inject_space(ele) for ele in DE_SAMPLES]
        verify_encode_token(tokenizer, noisy_en_samples + noisy_de_samples,
                            gt_en_tokenized + gt_de_tokenized)
        # Test for decode
        verify_decode(tokenizer, EN_SAMPLES + DE_SAMPLES, str)
        # Test for encode_with_offsets
        verify_encode_token_with_offsets(tokenizer, noisy_en_samples + noisy_de_samples)
    verify_decode_no_vocab_raise(tokenizer)

    # Test for output_type = int
    vocab = Vocab(collections.Counter(sum(gt_en_tokenized + gt_de_tokenized,
                                          [])))
    tokenizer.set_vocab(vocab)
    verify_decode(tokenizer, EN_SAMPLES + DE_SAMPLES, int)
    verify_pickleble(tokenizer, WhitespaceTokenizer)
    verify_encode_token_with_offsets(tokenizer, EN_SAMPLES + DE_SAMPLES)
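random_inject_space is not defined in these excerpts; a hypothetical sketch of a noise helper with the same intent, widening existing whitespace so the expected token sequence is unchanged (the real helper may also inject tabs or newlines):

import random

def random_inject_space(sentence, max_extra=3):
    # Hypothetical helper: replace each space with a random run of
    # 1..max_extra spaces; the tokens themselves stay identical.
    return ''.join(ch if ch != ' ' else ' ' * random.randint(1, max_extra)
                   for ch in sentence)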
Example #8
def main(args):
    corpus_path_list = args.corpus
    if args.save_dir is None:
        args.save_dir = args.model
    for corpus_path in corpus_path_list:
        if not os.path.exists(corpus_path):
            raise ValueError(
                'The path="{}" provided by --corpus does not exist!'.format(
                    corpus_path))
    print('Learn the "{}" subword model based on {}.'.format(
        args.model, args.corpus))
    os.makedirs(args.save_dir, exist_ok=True)
    model_prefix = os.path.join(args.save_dir, args.model)
    print('Save the subword model to {}.model'.format(model_prefix))
    print('Save the vocabulary to {}.vocab'.format(model_prefix))
    print()
    print('------- Start Training -------------')
    special_tokens_kv = OrderedDict()
    if not args.disable_unk:
        special_tokens_kv['unk_token'] = Vocab.UNK_TOKEN
    if not args.disable_bos:
        special_tokens_kv['bos_token'] = Vocab.BOS_TOKEN
    if not args.disable_eos:
        special_tokens_kv['eos_token'] = Vocab.EOS_TOKEN
    if not args.disable_pad:
        special_tokens_kv['pad_token'] = Vocab.PAD_TOKEN
    # split custom special tokens
    if args.model in ['yttm'] and len(args.custom_special_tokens) > 0:
        raise ValueError(
            'model {} does not support custom_special_tokens'.format(args.model))
    additional_custom_special_token = OrderedDict()
    for custom_special_token in args.custom_special_tokens:
        kv = custom_special_token.split('=')
        if len(kv) != 2:
            raise ValueError(
                'parameter {} has the wrong format'.format(custom_special_token))
        k, v = kv[0], kv[1]
        if k in special_tokens_kv:
            warnings.warn(
                f'There are overlaps between the custom special tokens and the'
                f' unk, bos, eos, pad tokens. The default token "{k}" will be '
                f'overwritten with "{v}".')
        special_tokens_kv[k] = v
        additional_custom_special_token[k] = v
    if args.model == 'hf_wordpiece':
        tokenizers = try_import_huggingface_tokenizers()
        if 'unk_token' not in special_tokens_kv or special_tokens_kv[
                'unk_token'] != '[UNK]':
            # TODO: the HF tokenizer must have the unk token.
            special_tokens_kv['unk_token'] = '[UNK]'
        if parse_version(tokenizers.__version__) < parse_version('0.8'):
            # Older versions of tokenizers (< 0.8):
            # hf_wordpiece must contain mask, cls and sep tokens;
            # custom-defined mask/cls/sep tokens can overwrite the default settings
            if 'mask_token' not in special_tokens_kv:
                special_tokens_kv['mask_token'] = Vocab.MASK_TOKEN
            if 'cls_token' not in special_tokens_kv:
                special_tokens_kv['cls_token'] = Vocab.CLS_TOKEN
            if 'sep_token' not in special_tokens_kv:
                special_tokens_kv['sep_token'] = Vocab.SEP_TOKEN
    special_tokens = list(special_tokens_kv.values())
    print('special tokens: ' + ', '.join(special_tokens))
    vocab = []
    if args.model == 'spm':
        try_import_sentencepiece()
        import sentencepiece as spm
        corpus_path = ','.join(corpus_path_list)
        script = '--input={} --model_prefix={} --vocab_size={} --character_coverage={} --input_sentence_size={}' \
                 .format(corpus_path, model_prefix, args.vocab_size, args.coverage, args.input_sentence_size)
        script += (' --unk_id=' +
                   str(list(special_tokens_kv.keys()).index('unk_token')))
        script += (' --bos_id=' + ('-1' if args.disable_bos else str(
            list(special_tokens_kv.keys()).index('bos_token'))))
        script += (' --eos_id=' + ('-1' if args.disable_eos else str(
            list(special_tokens_kv.keys()).index('eos_token'))))
        script += (' --pad_id=' + ('-1' if args.disable_pad else str(
            list(special_tokens_kv.keys()).index('pad_token'))))
        if len(additional_custom_special_token) > 0:
            script += (
                ' --control_symbols=' +
                ','.join(list(additional_custom_special_token.values())))
        print(script)
        spm.SentencePieceTrainer.Train(script)
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<s>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '</s>'
        # build spm vocab
        spm_model = spm.SentencePieceProcessor()
        spm_model.load(model_prefix + '.model')
        vocab = [spm_model.id_to_piece(i) for i in range(len(spm_model))]
        os.remove(model_prefix + '.vocab')
    elif args.model == 'subword_nmt':
        try_import_subword_nmt()
        from subword_nmt import learn_bpe
        corpus_path = cat_corpus(corpus_path_list)\
            if len(corpus_path_list) > 1 else corpus_path_list[0]
        # build model
        with open(corpus_path, 'r', encoding='utf-8') as fc,\
             open(model_prefix + '.model', 'w', encoding='utf-8') as fm:
            learn_bpe.learn_bpe(fc,
                                fm,
                                args.vocab_size - len(special_tokens),
                                total_symbols=True)
        # build vocab
        with open(corpus_path, 'r', encoding='utf-8') as fc, \
             open(model_prefix + '.model', 'r', encoding='utf-8') as fm:
            vocab.extend(special_tokens)
            uniq_chars_internal = set()
            uniq_chars_final = set()
            uniq_words = set()
            for line in fc:
                for word in line.strip('\r\n ').split(' '):
                    if word:
                        uniq_words.add(word)
            # this code block is the same as what
            # https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/learn_bpe.py does
            uniq_words = [
                tuple(x[:-1]) + (x[-1] + '</w>', ) for x in uniq_words
            ]
            for word in uniq_words:
                for char in word[:-1]:
                    uniq_chars_internal.add(char)
                uniq_chars_final.add(word[-1])
            # sort to ensure the same settings produce the same vocab
            vocab.extend(sorted(list(uniq_chars_internal)))
            vocab.extend(sorted(list(uniq_chars_final)))
            fm.readline()
            pair = fm.readline()
            while pair:
                vocab.append(pair.replace(' ', '', 1).strip())
                pair = fm.readline()
        if len(corpus_path_list) > 1:
            os.remove(corpus_path)
    elif args.model == 'yttm':
        try_import_yttm()
        import youtokentome as yttm
        corpus_path = cat_corpus(corpus_path_list)\
            if len(corpus_path_list) > 1 else corpus_path_list[0]
        tokenizer = yttm.BPE.train(
            data=corpus_path,
            model=model_prefix + '.model',
            vocab_size=args.vocab_size,
            coverage=args.coverage,
            n_threads=args.n_threads,
            unk_id=special_tokens.index(Vocab.UNK_TOKEN),
            bos_id=-1
            if args.disable_bos else special_tokens.index(Vocab.BOS_TOKEN),
            eos_id=-1
            if args.disable_eos else special_tokens.index(Vocab.EOS_TOKEN),
            pad_id=-1
            if args.disable_pad else special_tokens.index(Vocab.PAD_TOKEN))
        vocab = tokenizer.vocab()
        if 'unk_token' in special_tokens_kv:
            special_tokens_kv['unk_token'] = '<UNK>'
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<BOS>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '<EOS>'
        if 'pad_token' in special_tokens_kv:
            special_tokens_kv['pad_token'] = '<PAD>'
        if len(corpus_path_list) > 1:
            os.remove(corpus_path)
    elif args.model in ['hf_bpe', 'hf_bytebpe', 'hf_wordpiece']:
        tokenizers = try_import_huggingface_tokenizers()
        if args.model == 'hf_bpe':
            split_on_whitespace_only = not args.split_punctuation
            tokenizer = tokenizers.CharBPETokenizer(
                lowercase=args.lowercase,
                bert_normalizer=args.bert_normalizer,
                split_on_whitespace_only=split_on_whitespace_only)
        elif args.model == 'hf_bytebpe':
            tokenizer = tokenizers.ByteLevelBPETokenizer(
                lowercase=args.lowercase)
        elif args.model == 'hf_wordpiece':
            unk_token = special_tokens_kv.get('unk_token', None)
            sep_token = special_tokens_kv.get('sep_token', None)
            cls_token = special_tokens_kv.get('cls_token', None)
            pad_token = special_tokens_kv.get('pad_token', None)
            mask_token = special_tokens_kv.get('mask_token', None)
            if args.bert_normalizer:
                strip_accents = None
                clean_text = True
                handle_chinese_chars = True
            else:
                strip_accents = False
                clean_text = False
                handle_chinese_chars = False
            tokenizer = tokenizers.BertWordPieceTokenizer(
                unk_token=unk_token,
                sep_token=sep_token,
                cls_token=cls_token,
                pad_token=pad_token,
                mask_token=mask_token,
                lowercase=args.lowercase,
                strip_accents=strip_accents,
                handle_chinese_chars=handle_chinese_chars,
                clean_text=clean_text)
        else:
            raise NotImplementedError
        tokenizer.train(corpus_path_list,
                        vocab_size=args.vocab_size,
                        show_progress=True,
                        special_tokens=special_tokens)
        # Deal with the API change of tokenizers >= 0.8
        if version.parse(tokenizers.__version__) >= version.parse('0.8'):
            save_model_path = model_prefix + '.model'
            tokenizer.save(save_model_path)
            model_info = json.load(open(save_model_path, encoding='utf-8'))
            special_tokens_in_tokenizer = model_info['added_tokens']
            assert len(special_tokens_in_tokenizer) == len(special_tokens)
            hf_vocab = model_info['model']['vocab']
            hf_vocab_sorted = sorted(list(hf_vocab.items()),
                                     key=lambda x: x[1])
            hf_vocab_ids = [ele[1] for ele in hf_vocab_sorted]
            assert min(hf_vocab_ids) == 0 and max(
                hf_vocab_ids) == len(hf_vocab_ids) - 1
            vocab = [ele[0] for ele in hf_vocab_sorted]
        else:
            tokenizer.save(args.save_dir, args.model)
            # we replace the huggingface vocab file with our Vocab implementation
            if args.model == 'hf_wordpiece':
                hf_vocab_file = model_prefix + '-vocab.txt'
                with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                    for line in fv:
                        vocab.append(line.strip())
            else:
                # Move the hf_${model}-merges.txt to hf_${model}.model
                os.rename(
                    os.path.join(args.save_dir,
                                 '{}-merges.txt'.format(args.model)),
                    os.path.join(args.save_dir, '{}.model'.format(args.model)))
                hf_vocab_file = model_prefix + '-vocab.json'
                with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                    vocab_kv = json.load(fv)
                    vocab_kv = sorted(list(vocab_kv.items()),
                                      key=lambda x: x[1])
                    for kv in vocab_kv:
                        vocab.append(kv[0])
            os.remove(hf_vocab_file)
    else:
        raise NotImplementedError
    vocab_obj = Vocab(vocab, **special_tokens_kv)
    vocab_obj.save(model_prefix + '.vocab')
    print('-------- Done Training -------------')
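cat_corpus is referenced above but not shown; a hypothetical sketch that concatenates the corpus files into a single temporary file and returns its path (the caller removes the file afterwards, as main() does):

import os
import tempfile

def cat_corpus(corpus_path_list):
    # Hypothetical helper: merge all corpus files into one temporary file
    # and return its path.
    fd, merged_path = tempfile.mkstemp(suffix='.txt')
    with os.fdopen(fd, 'w', encoding='utf-8') as fout:
        for path in corpus_path_list:
            with open(path, 'r', encoding='utf-8') as fin:
                for line in fin:
                    fout.write(line)
    return merged_path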
Example #9
def main(args):
    corpus_path_list = args.corpus
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    model_prefix = os.path.join(args.save_dir, args.model)
    special_tokens_kv = OrderedDict()
    # unk is always required
    special_tokens_kv['unk_token'] = Vocab.UNK_TOKEN
    if not args.disable_bos:
        special_tokens_kv['bos_token'] = Vocab.BOS_TOKEN
    if not args.disable_eos:
        special_tokens_kv['eos_token'] = Vocab.EOS_TOKEN
    if not args.disable_pad:
        special_tokens_kv['pad_token'] = Vocab.PAD_TOKEN
    # split custom special tokens
    if args.model in ['yttm'] and len(args.custom_special_tokens) > 0:
        raise ValueError(
            'model {} does not support custom_special_tokens'.format(args.model))
    for custom_special_token in args.custom_special_tokens:
        kv = custom_special_token.split('=')
        if len(kv) != 2:
            raise ValueError(
                'parameter {} has the wrong format'.format(custom_special_token))
        k, v = kv[0], kv[1]
        if k in special_tokens_kv:
            raise ValueError(
                'There are overlaps between the custom special tokens and the'
                ' unk, bos, eos, pad tokens')
        special_tokens_kv[k] = v
    # hf_wordpiece must contain mask, cls and sep tokens;
    # custom-defined mask/cls/sep tokens can overwrite the default settings
    if args.model == 'hf_wordpiece':
        if 'mask_token' not in special_tokens_kv:
            special_tokens_kv['mask_token'] = Vocab.MASK_TOKEN
        if 'cls_token' not in special_tokens_kv:
            special_tokens_kv['cls_token'] = Vocab.CLS_TOKEN
        if 'sep_token' not in special_tokens_kv:
            special_tokens_kv['sep_token'] = Vocab.SEP_TOKEN
    special_tokens = list(special_tokens_kv.values())
    print('special tokens: ' + ', '.join(special_tokens))
    vocab = []
    if args.model == 'spm':
        try_import_sentencepiece()
        import sentencepiece as spm
        corpus_path = ','.join(corpus_path_list)
        script = '--input={} --model_prefix={} --vocab_size={} --character_coverage={} --input_sentence_size={}' \
                 .format(corpus_path, model_prefix, args.vocab_size, args.coverage, args.input_sentence_size)
        script += (' --unk_id=' + str(special_tokens.index(Vocab.UNK_TOKEN)))
        script += (' --bos_id=' + ('-1' if args.disable_bos else str(
            special_tokens.index(Vocab.BOS_TOKEN))))
        script += (' --eos_id=' + ('-1' if args.disable_eos else str(
            special_tokens.index(Vocab.EOS_TOKEN))))
        script += (' --pad_id=' + ('-1' if args.disable_pad else str(
            special_tokens.index(Vocab.PAD_TOKEN))))
        if len(args.custom_special_tokens) > 0:
            ids_in_script = script.count('_id')
            script += (' --control_symbols=' +
                       ','.join(special_tokens[ids_in_script:]))
        print(script)
        spm.SentencePieceTrainer.Train(script)
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<s>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '</s>'
        # build spm vocab
        spm_model = spm.SentencePieceProcessor()
        spm_model.load(model_prefix + '.model')
        vocab = [spm_model.id_to_piece(i) for i in range(len(spm_model))]
        os.remove(model_prefix + '.vocab')
    elif args.model == 'subword_nmt':
        try_import_subword_nmt()
        from subword_nmt import learn_bpe
        corpus_path = cat_corpus(corpus_path_list)\
            if len(corpus_path_list) > 1 else corpus_path_list[0]
        # build model
        with open(corpus_path, 'r', encoding='utf-8') as fc,\
             open(model_prefix + '.model', 'w', encoding='utf-8') as fm:
            learn_bpe.learn_bpe(fc,
                                fm,
                                args.vocab_size - len(special_tokens),
                                total_symbols=True)
        # build vocab
        with open(corpus_path, 'r', encoding='utf-8') as fc, \
             open(model_prefix + '.model', 'r', encoding='utf-8') as fm:
            vocab.extend(special_tokens)
            uniq_chars_internal = set()
            uniq_chars_final = set()
            uniq_words = set()
            for line in fc:
                for word in line.strip('\r\n ').split(' '):
                    if word:
                        uniq_words.add(word)
            # this code block is the same as what
            # https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/learn_bpe.py does
            uniq_words = [
                tuple(x[:-1]) + (x[-1] + '</w>', ) for x in uniq_words
            ]
            for word in uniq_words:
                for char in word[:-1]:
                    uniq_chars_internal.add(char)
                uniq_chars_final.add(word[-1])
            # sort to ensure the same settings produce the same vocab
            vocab.extend(sorted(list(uniq_chars_internal)))
            vocab.extend(sorted(list(uniq_chars_final)))
            fm.readline()
            pair = fm.readline()
            while pair:
                vocab.append(pair.replace(' ', '', 1).strip())
                pair = fm.readline()
        if len(corpus_path_list) > 1:
            os.remove(corpus_path)
    elif args.model == 'yttm':
        try_import_yttm()
        import youtokentome as yttm
        corpus_path = cat_corpus(corpus_path_list)\
            if len(corpus_path_list) > 1 else corpus_path_list[0]
        tokenizer = yttm.BPE.train(
            data=corpus_path,
            model=model_prefix + '.model',
            vocab_size=args.vocab_size,
            coverage=args.coverage,
            n_threads=args.n_threads,
            unk_id=special_tokens.index(Vocab.UNK_TOKEN),
            bos_id=-1
            if args.disable_bos else special_tokens.index(Vocab.BOS_TOKEN),
            eos_id=-1
            if args.disable_eos else special_tokens.index(Vocab.EOS_TOKEN),
            pad_id=-1
            if args.disable_pad else special_tokens.index(Vocab.PAD_TOKEN))
        vocab = tokenizer.vocab()
        if 'unk_token' in special_tokens_kv:
            special_tokens_kv['unk_token'] = '<UNK>'
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<BOS>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '<EOS>'
        if 'pad_token' in special_tokens_kv:
            special_tokens_kv['pad_token'] = '<PAD>'
        if len(corpus_path_list) > 1:
            os.remove(corpus_path)
    elif args.model in ['hf_bpe', 'hf_bytebpe', 'hf_wordpiece']:
        tokenizers = try_import_huggingface_tokenizers()
        if args.model == 'hf_bpe':
            tokenizer = tokenizers.CharBPETokenizer(lowercase=args.lowercase)
        elif args.model == 'hf_bytebpe':
            tokenizer = tokenizers.ByteLevelBPETokenizer(
                lowercase=args.lowercase)
        elif args.model == 'hf_wordpiece':
            tokenizer = tokenizers.BertWordPieceTokenizer(
                lowercase=args.lowercase, strip_accents=args.strip_accents)
        else:
            raise NotImplementedError
        tokenizer.train(corpus_path_list,
                        vocab_size=args.vocab_size,
                        show_progress=True,
                        special_tokens=special_tokens)
        tokenizer.save(args.save_dir, args.model)
        # we replace the huggingface vocab file with our Vocab implementation
        if args.model == 'hf_wordpiece':
            hf_vocab_file = model_prefix + '-vocab.txt'
            with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                for line in fv:
                    vocab.append(line.strip())
        else:
            # Move the hf_${model}-merges.txt to hf_${model}.model
            os.rename(
                os.path.join(args.save_dir,
                             '{}-merges.txt'.format(args.model)),
                os.path.join(args.save_dir, '{}.model'.format(args.model)))
            hf_vocab_file = model_prefix + '-vocab.json'
            with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                vocab_kv = json.load(fv)
                vocab_kv = sorted(list(vocab_kv.items()), key=lambda x: x[1])
                for kv in vocab_kv:
                    vocab.append(kv[0])
        os.remove(hf_vocab_file)
    else:
        raise NotImplementedError
    unk_token = special_tokens_kv.pop('unk_token')
    vocab_obj = Vocab(vocab, unk_token=unk_token, **special_tokens_kv)
    vocab_obj.save(model_prefix + '.vocab')