示例#1
0
def main(args):
    """Learn a subword model from the corpus files and save its vocabulary.

    The backend is chosen by ``args.model``: ``spm``, ``subword_nmt``,
    ``yttm``, ``hf_bpe``, ``hf_bytebpe`` or ``hf_wordpiece``. Outputs are
    written into ``args.save_dir`` using ``args.model`` as the file prefix,
    and the final vocabulary is saved to ``<prefix>.vocab`` via ``Vocab.save``.

    Parameters
    ----------
    args
        Parsed options. Attributes read here: ``corpus`` (list of paths),
        ``save_dir`` (falls back to ``model`` when ``None``; the attribute is
        updated in place), ``model``, ``vocab_size``, ``disable_unk``,
        ``disable_bos``, ``disable_eos``, ``disable_pad``,
        ``custom_special_tokens`` (``key=value`` strings) and, depending on
        the backend, ``coverage``, ``input_sentence_size``, ``n_threads``,
        ``lowercase``, ``bert_normalizer`` and ``split_punctuation``.

    Raises
    ------
    ValueError
        If a corpus path does not exist, a custom special token is not in
        ``key=value`` form, custom special tokens are given for ``yttm``,
        or ``spm`` is requested without an unk token.
    NotImplementedError
        If ``args.model`` is not one of the supported backends.
    """
    corpus_path_list = args.corpus
    if args.save_dir is None:
        args.save_dir = args.model
    for corpus_path in corpus_path_list:
        if not os.path.exists(corpus_path):
            raise ValueError(
                'The path="{}" provided by --corpus does not exist!'.format(
                    corpus_path))
    print('Learn the "{}"s subword model based on {}.'.format(
        args.model, args.corpus))
    os.makedirs(args.save_dir, exist_ok=True)
    model_prefix = os.path.join(args.save_dir, args.model)
    print('Save the subword model to {}.model'.format(model_prefix))
    print('Save the vocabulary to {}.vocab'.format(model_prefix))
    print()
    print('------- Start Training -------------')

    # Default special tokens; insertion order defines their ids.
    special_tokens_kv = OrderedDict()
    if not args.disable_unk:
        special_tokens_kv['unk_token'] = Vocab.UNK_TOKEN
    if not args.disable_bos:
        special_tokens_kv['bos_token'] = Vocab.BOS_TOKEN
    if not args.disable_eos:
        special_tokens_kv['eos_token'] = Vocab.EOS_TOKEN
    if not args.disable_pad:
        special_tokens_kv['pad_token'] = Vocab.PAD_TOKEN

    # Split the custom special tokens given as "key=value" strings.
    # A missing (None) option is treated the same as an empty list.
    custom_special_tokens = args.custom_special_tokens or []
    if args.model == 'yttm' and custom_special_tokens:
        raise ValueError(
            'model {} do not support custom_special_tokens'.format(args.model))
    additional_custom_special_token = OrderedDict()
    for custom_special_token in custom_special_tokens:
        kv = custom_special_token.split('=')
        if len(kv) != 2:
            raise ValueError(
                'parameter {} has wrong format'.format(custom_special_token))
        k, v = kv
        if k in special_tokens_kv:
            warnings.warn(
                f'There are overlaps between the custom special tokens and the'
                f' unk, bos, eos, pad tokens. Currently, we will overwrite the '
                f'default tokens. We will overwrite "{k}" to "{v}"')
        special_tokens_kv[k] = v
        additional_custom_special_token[k] = v

    if args.model == 'hf_wordpiece':
        tokenizers = try_import_huggingface_tokenizers()
        if special_tokens_kv.get('unk_token') != '[UNK]':
            # TODO, HF Tokenizer must have the unk token.
            special_tokens_kv['unk_token'] = '[UNK]'
        if parse_version(tokenizers.__version__) < parse_version('0.8'):
            # The older version of Tokenizers
            # hf_wordpiece must contain mask, cls and sep tokens
            # the custom defined mask,cls,sep can overwrite the default settings
            special_tokens_kv.setdefault('mask_token', Vocab.MASK_TOKEN)
            special_tokens_kv.setdefault('cls_token', Vocab.CLS_TOKEN)
            special_tokens_kv.setdefault('sep_token', Vocab.SEP_TOKEN)

    special_tokens = list(special_tokens_kv.values())
    print('special tokens: ' + ', '.join(special_tokens))
    vocab = []
    if args.model == 'spm':
        try_import_sentencepiece()
        import sentencepiece as spm
        if 'unk_token' not in special_tokens_kv:
            # The id lookup below needs the unk token; fail with a readable
            # message instead of "'unk_token' is not in list".
            raise ValueError(
                'model {} requires the unk token, but it is disabled'.format(
                    args.model))
        corpus_path = ','.join(corpus_path_list)
        script = '--input={} --model_prefix={} --vocab_size={} --character_coverage={} --input_sentence_size={}' \
                 .format(corpus_path, model_prefix, args.vocab_size, args.coverage, args.input_sentence_size)
        # The position of each key in the ordered dict is used as its id.
        token_keys = list(special_tokens_kv)
        script += ' --unk_id=' + str(token_keys.index('unk_token'))
        script += ' --bos_id=' + ('-1' if args.disable_bos else str(
            token_keys.index('bos_token')))
        script += ' --eos_id=' + ('-1' if args.disable_eos else str(
            token_keys.index('eos_token')))
        script += ' --pad_id=' + ('-1' if args.disable_pad else str(
            token_keys.index('pad_token')))
        if additional_custom_special_token:
            script += (' --control_symbols=' +
                       ','.join(additional_custom_special_token.values()))
        print(script)
        spm.SentencePieceTrainer.Train(script)
        # The bos/eos entries are recorded under these fixed piece names.
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<s>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '</s>'
        # build spm vocab
        spm_model = spm.SentencePieceProcessor()
        spm_model.load(model_prefix + '.model')
        vocab = [spm_model.id_to_piece(i) for i in range(len(spm_model))]
        # The trainer's own .vocab file is replaced by the Vocab saved below.
        os.remove(model_prefix + '.vocab')
    elif args.model == 'subword_nmt':
        try_import_subword_nmt()
        from subword_nmt import learn_bpe
        # Several corpora are concatenated into one file by cat_corpus; that
        # file is removed afterwards, even when training fails.
        use_merged_corpus = len(corpus_path_list) > 1
        corpus_path = cat_corpus(corpus_path_list)\
            if use_merged_corpus else corpus_path_list[0]
        try:
            # build model
            with open(corpus_path, 'r', encoding='utf-8') as fc,\
                 open(model_prefix + '.model', 'w', encoding='utf-8') as fm:
                learn_bpe.learn_bpe(fc,
                                    fm,
                                    args.vocab_size - len(special_tokens),
                                    total_symbols=True)
            # build vocab
            with open(corpus_path, 'r', encoding='utf-8') as fc, \
                 open(model_prefix + '.model', 'r', encoding='utf-8') as fm:
                vocab.extend(special_tokens)
                uniq_words = {
                    word
                    for line in fc
                    for word in line.strip('\r\n ').split(' ') if word
                }
                # this code piece is same as
                # https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/learn_bpe.py shows
                uniq_words = [
                    tuple(x[:-1]) + (x[-1] + '</w>', ) for x in uniq_words
                ]
                uniq_chars_internal = set()
                uniq_chars_final = set()
                for word in uniq_words:
                    uniq_chars_internal.update(word[:-1])
                    uniq_chars_final.add(word[-1])
                # sort to ensure the same settings produce the same vocab
                vocab.extend(sorted(uniq_chars_internal))
                vocab.extend(sorted(uniq_chars_final))
                # Skip the first line of the model file, then turn every
                # "a b" merge pair into the merged symbol "ab".
                fm.readline()
                for pair in fm:
                    vocab.append(pair.replace(' ', '', 1).strip())
        finally:
            if use_merged_corpus:
                os.remove(corpus_path)
    elif args.model == 'yttm':
        try_import_yttm()
        import youtokentome as yttm
        use_merged_corpus = len(corpus_path_list) > 1
        corpus_path = cat_corpus(corpus_path_list)\
            if use_merged_corpus else corpus_path_list[0]
        try:
            tokenizer = yttm.BPE.train(
                data=corpus_path,
                model=model_prefix + '.model',
                vocab_size=args.vocab_size,
                coverage=args.coverage,
                n_threads=args.n_threads,
                unk_id=special_tokens.index(Vocab.UNK_TOKEN),
                bos_id=-1
                if args.disable_bos else special_tokens.index(Vocab.BOS_TOKEN),
                eos_id=-1
                if args.disable_eos else special_tokens.index(Vocab.EOS_TOKEN),
                pad_id=-1
                if args.disable_pad else special_tokens.index(Vocab.PAD_TOKEN))
            vocab = tokenizer.vocab()
        finally:
            # Remove the concatenated corpus even when training fails.
            if use_merged_corpus:
                os.remove(corpus_path)
        # The enabled special tokens are recorded under these fixed names.
        for key, token in (('unk_token', '<UNK>'), ('bos_token', '<BOS>'),
                           ('eos_token', '<EOS>'), ('pad_token', '<PAD>')):
            if key in special_tokens_kv:
                special_tokens_kv[key] = token
    elif args.model in ['hf_bpe', 'hf_bytebpe', 'hf_wordpiece']:
        tokenizers = try_import_huggingface_tokenizers()
        if args.model == 'hf_bpe':
            split_on_whitespace_only = not args.split_punctuation
            tokenizer = tokenizers.CharBPETokenizer(
                lowercase=args.lowercase,
                bert_normalizer=args.bert_normalizer,
                split_on_whitespace_only=split_on_whitespace_only)
        elif args.model == 'hf_bytebpe':
            tokenizer = tokenizers.ByteLevelBPETokenizer(
                lowercase=args.lowercase)
        elif args.model == 'hf_wordpiece':
            if args.bert_normalizer:
                strip_accents = None
                clean_text = True
                handle_chinese_chars = True
            else:
                strip_accents = False
                clean_text = False
                handle_chinese_chars = False
            tokenizer = tokenizers.BertWordPieceTokenizer(
                unk_token=special_tokens_kv.get('unk_token', None),
                sep_token=special_tokens_kv.get('sep_token', None),
                cls_token=special_tokens_kv.get('cls_token', None),
                pad_token=special_tokens_kv.get('pad_token', None),
                mask_token=special_tokens_kv.get('mask_token', None),
                lowercase=args.lowercase,
                strip_accents=strip_accents,
                handle_chinese_chars=handle_chinese_chars,
                clean_text=clean_text)
        else:
            raise NotImplementedError
        tokenizer.train(corpus_path_list,
                        vocab_size=args.vocab_size,
                        show_progress=True,
                        special_tokens=special_tokens)
        # Deal with the API change of tokenizers >= 0.8
        if version.parse(tokenizers.__version__) >= version.parse('0.8'):
            save_model_path = model_prefix + '.model'
            tokenizer.save(save_model_path)
            # Read the saved JSON back with a context manager so the file
            # handle is closed deterministically.
            with open(save_model_path, encoding='utf-8') as f_model:
                model_info = json.load(f_model)
            special_tokens_in_tokenizer = model_info['added_tokens']
            assert len(special_tokens_in_tokenizer) == len(special_tokens)
            hf_vocab = model_info['model']['vocab']
            hf_vocab_sorted = sorted(hf_vocab.items(), key=lambda x: x[1])
            hf_vocab_ids = [ele[1] for ele in hf_vocab_sorted]
            # The ids must form the contiguous range 0..len-1.
            assert min(hf_vocab_ids) == 0 and max(
                hf_vocab_ids) == len(hf_vocab_ids) - 1
            vocab = [ele[0] for ele in hf_vocab_sorted]
        else:
            tokenizer.save(args.save_dir, args.model)
            # we replace the huggingface vocab file with our Vocab implementation
            if args.model == 'hf_wordpiece':
                hf_vocab_file = model_prefix + '-vocab.txt'
                with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                    vocab.extend(line.strip() for line in fv)
            else:
                # Move the hf_${model}-merges.txt to hf_${model}.models
                os.rename(
                    os.path.join(args.save_dir,
                                 '{}-merges.txt'.format(args.model)),
                    os.path.join(args.save_dir, '{}.model'.format(args.model)))
                hf_vocab_file = model_prefix + '-vocab.json'
                with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                    vocab_kv = json.load(fv)
                # Order the tokens by their ids.
                vocab.extend(
                    token for token, _ in sorted(vocab_kv.items(),
                                                 key=lambda x: x[1]))
            os.remove(hf_vocab_file)
    else:
        raise NotImplementedError
    vocab_obj = Vocab(vocab, **special_tokens_kv)
    vocab_obj.save(model_prefix + '.vocab')
    print('-------- Done Training -------------')
示例#2
0
def main(args):
    """Learn a subword model from the corpus files and save its vocabulary.

    The backend is chosen by ``args.model``: ``spm``, ``subword_nmt``,
    ``yttm``, ``hf_bpe``, ``hf_bytebpe`` or ``hf_wordpiece``. Outputs are
    written into ``args.save_dir`` using ``args.model`` as the file prefix,
    and the final vocabulary is saved to ``<prefix>.vocab`` via ``Vocab.save``.

    Parameters
    ----------
    args
        Parsed options. Attributes read here: ``corpus`` (list of paths),
        ``save_dir``, ``model``, ``vocab_size``, ``disable_bos``,
        ``disable_eos``, ``disable_pad``, ``custom_special_tokens``
        (``key=value`` strings) and, depending on the backend, ``coverage``,
        ``input_sentence_size``, ``n_threads``, ``lowercase`` and
        ``strip_accents``.

    Raises
    ------
    ValueError
        If a custom special token is not in ``key=value`` form, reuses the
        key of a default special token, or is given for ``yttm``.
    NotImplementedError
        If ``args.model`` is not one of the supported backends.
    """
    corpus_path_list = args.corpus
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(args.save_dir, exist_ok=True)
    model_prefix = os.path.join(args.save_dir, args.model)
    special_tokens_kv = OrderedDict()
    # unk is always required
    special_tokens_kv['unk_token'] = Vocab.UNK_TOKEN
    if not args.disable_bos:
        special_tokens_kv['bos_token'] = Vocab.BOS_TOKEN
    if not args.disable_eos:
        special_tokens_kv['eos_token'] = Vocab.EOS_TOKEN
    if not args.disable_pad:
        special_tokens_kv['pad_token'] = Vocab.PAD_TOKEN
    # split custom special tokens
    if args.model in ['yttm'] and len(args.custom_special_tokens) > 0:
        raise ValueError(
            'model {} do not support custom_special_tokens'.format(args.model))
    # Values of the user-defined tokens, in the order given. They are tracked
    # separately so the spm control symbols do not depend on how many of the
    # default tokens are enabled.
    custom_token_values = []
    for custom_special_token in args.custom_special_tokens:
        kv = custom_special_token.split('=')
        if len(kv) != 2:
            raise ValueError(
                'parameter {} has wrong format'.format(custom_special_token))
        k, v = kv
        if k in special_tokens_kv:
            raise ValueError(
                'There are overlaps between the custom special tokens and the'
                ' unk, bos, eos, pad tokens')
        special_tokens_kv[k] = v
        custom_token_values.append(v)
    # hf_wordpiece must contain mask, cls and sep tokens
    # the custom defined mask,cls,sep can overwrite the default settings
    if args.model == 'hf_wordpiece':
        special_tokens_kv.setdefault('mask_token', Vocab.MASK_TOKEN)
        special_tokens_kv.setdefault('cls_token', Vocab.CLS_TOKEN)
        special_tokens_kv.setdefault('sep_token', Vocab.SEP_TOKEN)
    special_tokens = list(special_tokens_kv.values())
    print('special tokens: ' + ', '.join(special_tokens))
    vocab = []
    if args.model == 'spm':
        try_import_sentencepiece()
        import sentencepiece as spm
        corpus_path = ','.join(corpus_path_list)
        script = '--input={} --model_prefix={} --vocab_size={} --character_coverage={} --input_sentence_size={}' \
                 .format(corpus_path, model_prefix, args.vocab_size, args.coverage, args.input_sentence_size)
        script += (' --unk_id=' + str(special_tokens.index(Vocab.UNK_TOKEN)))
        script += (' --bos_id=' + ('-1' if args.disable_bos else str(
            special_tokens.index(Vocab.BOS_TOKEN))))
        script += (' --eos_id=' + ('-1' if args.disable_eos else str(
            special_tokens.index(Vocab.EOS_TOKEN))))
        script += (' --pad_id=' + ('-1' if args.disable_pad else str(
            special_tokens.index(Vocab.PAD_TOKEN))))
        if custom_token_values:
            # Pass exactly the user-defined tokens. Counting "_id" in the
            # script (the previous approach) skipped custom tokens whenever
            # bos/eos/pad was disabled or a path contained "_id".
            script += ' --control_symbols=' + ','.join(custom_token_values)
        print(script)
        spm.SentencePieceTrainer.Train(script)
        # The bos/eos entries are recorded under these fixed piece names.
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<s>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '</s>'
        # build spm vocab
        spm_model = spm.SentencePieceProcessor()
        spm_model.load(model_prefix + '.model')
        vocab = [spm_model.id_to_piece(i) for i in range(len(spm_model))]
        # The trainer's own .vocab file is replaced by the Vocab saved below.
        os.remove(model_prefix + '.vocab')
    elif args.model == 'subword_nmt':
        try_import_subword_nmt()
        from subword_nmt import learn_bpe
        # Several corpora are concatenated into one file by cat_corpus; that
        # file is removed afterwards, even when training fails.
        use_merged_corpus = len(corpus_path_list) > 1
        corpus_path = cat_corpus(corpus_path_list)\
            if use_merged_corpus else corpus_path_list[0]
        try:
            # build model
            with open(corpus_path, 'r', encoding='utf-8') as fc,\
                 open(model_prefix + '.model', 'w', encoding='utf-8') as fm:
                learn_bpe.learn_bpe(fc,
                                    fm,
                                    args.vocab_size - len(special_tokens),
                                    total_symbols=True)
            # build vocab
            with open(corpus_path, 'r', encoding='utf-8') as fc, \
                 open(model_prefix + '.model', 'r', encoding='utf-8') as fm:
                vocab.extend(special_tokens)
                uniq_words = {
                    word
                    for line in fc
                    for word in line.strip('\r\n ').split(' ') if word
                }
                # this code piece is same as
                # https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/learn_bpe.py shows
                uniq_words = [
                    tuple(x[:-1]) + (x[-1] + '</w>', ) for x in uniq_words
                ]
                uniq_chars_internal = set()
                uniq_chars_final = set()
                for word in uniq_words:
                    uniq_chars_internal.update(word[:-1])
                    uniq_chars_final.add(word[-1])
                # sort to ensure the same settings produce the same vocab
                vocab.extend(sorted(uniq_chars_internal))
                vocab.extend(sorted(uniq_chars_final))
                # Skip the first line of the model file, then turn every
                # "a b" merge pair into the merged symbol "ab".
                fm.readline()
                for pair in fm:
                    vocab.append(pair.replace(' ', '', 1).strip())
        finally:
            if use_merged_corpus:
                os.remove(corpus_path)
    elif args.model == 'yttm':
        try_import_yttm()
        import youtokentome as yttm
        use_merged_corpus = len(corpus_path_list) > 1
        corpus_path = cat_corpus(corpus_path_list)\
            if use_merged_corpus else corpus_path_list[0]
        try:
            tokenizer = yttm.BPE.train(
                data=corpus_path,
                model=model_prefix + '.model',
                vocab_size=args.vocab_size,
                coverage=args.coverage,
                n_threads=args.n_threads,
                unk_id=special_tokens.index(Vocab.UNK_TOKEN),
                bos_id=-1
                if args.disable_bos else special_tokens.index(Vocab.BOS_TOKEN),
                eos_id=-1
                if args.disable_eos else special_tokens.index(Vocab.EOS_TOKEN),
                pad_id=-1
                if args.disable_pad else special_tokens.index(Vocab.PAD_TOKEN))
            vocab = tokenizer.vocab()
        finally:
            # Remove the concatenated corpus even when training fails.
            if use_merged_corpus:
                os.remove(corpus_path)
        # The enabled special tokens are recorded under these fixed names.
        for key, token in (('unk_token', '<UNK>'), ('bos_token', '<BOS>'),
                           ('eos_token', '<EOS>'), ('pad_token', '<PAD>')):
            if key in special_tokens_kv:
                special_tokens_kv[key] = token
    elif args.model in ['hf_bpe', 'hf_bytebpe', 'hf_wordpiece']:
        tokenizers = try_import_huggingface_tokenizers()
        if args.model == 'hf_bpe':
            tokenizer = tokenizers.CharBPETokenizer(lowercase=args.lowercase)
        elif args.model == 'hf_bytebpe':
            tokenizer = tokenizers.ByteLevelBPETokenizer(
                lowercase=args.lowercase)
        elif args.model == 'hf_wordpiece':
            tokenizer = tokenizers.BertWordPieceTokenizer(
                lowercase=args.lowercase, strip_accents=args.strip_accents)
        else:
            raise NotImplementedError
        tokenizer.train(corpus_path_list,
                        vocab_size=args.vocab_size,
                        show_progress=True,
                        special_tokens=special_tokens)
        tokenizer.save(args.save_dir, args.model)
        # we replace the huggingface vocab file with our Vocab implementation
        if args.model == 'hf_wordpiece':
            hf_vocab_file = model_prefix + '-vocab.txt'
            with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                vocab.extend(line.strip() for line in fv)
        else:
            # Move the hf_${model}-merges.txt to hf_${model}.models
            os.rename(
                os.path.join(args.save_dir,
                             '{}-merges.txt'.format(args.model)),
                os.path.join(args.save_dir, '{}.model'.format(args.model)))
            hf_vocab_file = model_prefix + '-vocab.json'
            with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                vocab_kv = json.load(fv)
            # Order the tokens by their ids.
            vocab.extend(
                token
                for token, _ in sorted(vocab_kv.items(), key=lambda x: x[1]))
        os.remove(hf_vocab_file)
    else:
        raise NotImplementedError
    # The unk token is passed explicitly; the rest go through as keywords.
    unk_token = special_tokens_kv.pop('unk_token')
    vocab_obj = Vocab(vocab, unk_token=unk_token, **special_tokens_kv)
    vocab_obj.save(model_prefix + '.vocab')