def main(args):
    corpus_path_list = args.corpus
    if args.save_dir is None:
        args.save_dir = args.model
    for corpus_path in corpus_path_list:
        if not os.path.exists(corpus_path):
            raise ValueError(
                'The path="{}" provided by --corpus does not exist!'.format(
                    corpus_path))
    print('Learn the "{}" subword model based on {}.'.format(
        args.model, args.corpus))
    os.makedirs(args.save_dir, exist_ok=True)
    model_prefix = os.path.join(args.save_dir, args.model)
    print('Save the subword model to {}.model'.format(model_prefix))
    print('Save the vocabulary to {}.vocab'.format(model_prefix))
    print()
    print('------- Start Training -------------')
    special_tokens_kv = OrderedDict()
    if not args.disable_unk:
        special_tokens_kv['unk_token'] = Vocab.UNK_TOKEN
    if not args.disable_bos:
        special_tokens_kv['bos_token'] = Vocab.BOS_TOKEN
    if not args.disable_eos:
        special_tokens_kv['eos_token'] = Vocab.EOS_TOKEN
    if not args.disable_pad:
        special_tokens_kv['pad_token'] = Vocab.PAD_TOKEN
    # split custom special tokens
    if args.model in ['yttm'] and len(args.custom_special_tokens) > 0:
        raise ValueError(
            'model {} does not support custom_special_tokens'.format(args.model))
    additional_custom_special_token = OrderedDict()
    for custom_special_token in args.custom_special_tokens:
        kv = custom_special_token.split('=')
        if not len(kv) == 2:
            raise ValueError(
                'parameter {} has the wrong format'.format(custom_special_token))
        k, v = kv[0], kv[1]
        if k in special_tokens_kv:
            warnings.warn(
                f'There are overlaps between the custom special tokens and the'
                f' unk, bos, eos, pad tokens. Currently, we will overwrite the '
                f'default tokens. We will overwrite "{k}" to "{v}"')
        special_tokens_kv[k] = v
        additional_custom_special_token[k] = v
    if args.model == 'hf_wordpiece':
        tokenizers = try_import_huggingface_tokenizers()
        if 'unk_token' not in special_tokens_kv or special_tokens_kv[
                'unk_token'] != '[UNK]':
            # TODO, HF Tokenizer must have the unk token.
            special_tokens_kv['unk_token'] = '[UNK]'
        if parse_version(tokenizers.__version__) < parse_version('0.8'):
            # The older version of Tokenizers
            # hf_wordpiece must contain mask, cls and sep tokens
            # the custom defined mask, cls, sep can overwrite the default settings
            if 'mask_token' not in special_tokens_kv:
                special_tokens_kv['mask_token'] = Vocab.MASK_TOKEN
            if 'cls_token' not in special_tokens_kv:
                special_tokens_kv['cls_token'] = Vocab.CLS_TOKEN
            if 'sep_token' not in special_tokens_kv:
                special_tokens_kv['sep_token'] = Vocab.SEP_TOKEN
    special_tokens = list(special_tokens_kv.values())
    print('special tokens: ' + ', '.join(special_tokens))
    vocab = []
    if args.model == 'spm':
        try_import_sentencepiece()
        import sentencepiece as spm
        corpus_path = ','.join(corpus_path_list)
        script = '--input={} --model_prefix={} --vocab_size={} --character_coverage={} --input_sentence_size={}' \
            .format(corpus_path, model_prefix, args.vocab_size, args.coverage, args.input_sentence_size)
        script += (' --unk_id=' +
                   str(list(special_tokens_kv.keys()).index('unk_token')))
        script += (' --bos_id=' + ('-1' if args.disable_bos else str(
            list(special_tokens_kv.keys()).index('bos_token'))))
        script += (' --eos_id=' + ('-1' if args.disable_eos else str(
            list(special_tokens_kv.keys()).index('eos_token'))))
        script += (' --pad_id=' + ('-1' if args.disable_pad else str(
            list(special_tokens_kv.keys()).index('pad_token'))))
        if len(additional_custom_special_token) > 0:
            script += (
                ' --control_symbols=' +
                ','.join(list(additional_custom_special_token.values())))
        print(script)
        spm.SentencePieceTrainer.Train(script)
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<s>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '</s>'
        # build spm vocab
        spm_model = spm.SentencePieceProcessor()
        spm_model.load(model_prefix + '.model')
        vocab = [spm_model.id_to_piece(i) for i in range(len(spm_model))]
        os.remove(model_prefix + '.vocab')
    elif args.model == 'subword_nmt':
        try_import_subword_nmt()
        from subword_nmt import learn_bpe
        corpus_path = cat_corpus(corpus_path_list)\
            if len(corpus_path_list) > 1 else corpus_path_list[0]
        # build model
        with open(corpus_path, 'r', encoding='utf-8') as fc,\
                open(model_prefix + '.model', 'w', encoding='utf-8') as fm:
            learn_bpe.learn_bpe(fc, fm, args.vocab_size - len(special_tokens),
                                total_symbols=True)
        # build vocab
        with open(corpus_path, 'r', encoding='utf-8') as fc, \
                open(model_prefix + '.model', 'r', encoding='utf-8') as fm:
            vocab.extend(special_tokens)
            uniq_chars_internal = set()
            uniq_chars_final = set()
            uniq_words = set()
            for line in fc:
                for word in line.strip('\r\n ').split(' '):
                    if word:
                        uniq_words.add(word)
            # this code piece is the same as
            # https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/learn_bpe.py
            uniq_words = [
                tuple(x[:-1]) + (x[-1] + '</w>', ) for x in uniq_words
            ]
            for word in uniq_words:
                for char in word[:-1]:
                    uniq_chars_internal.add(char)
                uniq_chars_final.add(word[-1])
            # sort to ensure the same settings produce the same vocab
            vocab.extend(sorted(list(uniq_chars_internal)))
            vocab.extend(sorted(list(uniq_chars_final)))
            fm.readline()
            pair = fm.readline()
            while pair:
                vocab.append(pair.replace(' ', '', 1).strip())
                pair = fm.readline()
        if len(corpus_path_list) > 1:
            os.remove(corpus_path)
    elif args.model == 'yttm':
        try_import_yttm()
        import youtokentome as yttm
        corpus_path = cat_corpus(corpus_path_list)\
            if len(corpus_path_list) > 1 else corpus_path_list[0]
        tokenizer = yttm.BPE.train(
            data=corpus_path,
            model=model_prefix + '.model',
            vocab_size=args.vocab_size,
            coverage=args.coverage,
            n_threads=args.n_threads,
            unk_id=special_tokens.index(Vocab.UNK_TOKEN),
            bos_id=-1 if args.disable_bos else special_tokens.index(Vocab.BOS_TOKEN),
            eos_id=-1 if args.disable_eos else special_tokens.index(Vocab.EOS_TOKEN),
            pad_id=-1 if args.disable_pad else special_tokens.index(Vocab.PAD_TOKEN))
        vocab = tokenizer.vocab()
        if 'unk_token' in special_tokens_kv:
            special_tokens_kv['unk_token'] = '<UNK>'
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<BOS>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '<EOS>'
        if 'pad_token' in special_tokens_kv:
            special_tokens_kv['pad_token'] = '<PAD>'
        if len(corpus_path_list) > 1:
            os.remove(corpus_path)
    elif args.model in ['hf_bpe', 'hf_bytebpe', 'hf_wordpiece']:
        tokenizers = try_import_huggingface_tokenizers()
        if args.model == 'hf_bpe':
            split_on_whitespace_only = not args.split_punctuation
            tokenizer = tokenizers.CharBPETokenizer(
                lowercase=args.lowercase,
                bert_normalizer=args.bert_normalizer,
                split_on_whitespace_only=split_on_whitespace_only)
        elif args.model == 'hf_bytebpe':
            tokenizer = tokenizers.ByteLevelBPETokenizer(
                lowercase=args.lowercase)
        elif args.model == 'hf_wordpiece':
            unk_token = special_tokens_kv.get('unk_token', None)
            sep_token = special_tokens_kv.get('sep_token', None)
            cls_token = special_tokens_kv.get('cls_token', None)
            pad_token = special_tokens_kv.get('pad_token', None)
            mask_token = special_tokens_kv.get('mask_token', None)
            if args.bert_normalizer:
                strip_accents = None
                clean_text = True
                handle_chinese_chars = True
            else:
                strip_accents = False
                clean_text = False
                handle_chinese_chars = False
            tokenizer = tokenizers.BertWordPieceTokenizer(
                unk_token=unk_token,
                sep_token=sep_token,
                cls_token=cls_token,
                pad_token=pad_token,
                mask_token=mask_token,
                lowercase=args.lowercase,
                strip_accents=strip_accents,
                handle_chinese_chars=handle_chinese_chars,
                clean_text=clean_text)
        else:
            raise NotImplementedError
        tokenizer.train(corpus_path_list,
                        vocab_size=args.vocab_size,
                        show_progress=True,
                        special_tokens=special_tokens)
        # Deal with the API change of tokenizers >= 0.8
        if version.parse(tokenizers.__version__) >= version.parse('0.8'):
            save_model_path = model_prefix + '.model'
            tokenizer.save(save_model_path)
            model_info = json.load(open(save_model_path, encoding='utf-8'))
            special_tokens_in_tokenizer = model_info['added_tokens']
            assert len(special_tokens_in_tokenizer) == len(special_tokens)
            hf_vocab = model_info['model']['vocab']
            hf_vocab_sorted = sorted(list(hf_vocab.items()), key=lambda x: x[1])
            hf_vocab_ids = [ele[1] for ele in hf_vocab_sorted]
            assert min(hf_vocab_ids) == 0 and max(
                hf_vocab_ids) == len(hf_vocab_ids) - 1
            vocab = [ele[0] for ele in hf_vocab_sorted]
        else:
            tokenizer.save(args.save_dir, args.model)
            # we replace the huggingface vocab file with our Vocab implementation
            if args.model == 'hf_wordpiece':
                hf_vocab_file = model_prefix + '-vocab.txt'
                with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                    for line in fv:
                        vocab.append(line.strip())
            else:
                # Move the hf_${model}-merges.txt to hf_${model}.model
                os.rename(
                    os.path.join(args.save_dir,
                                 '{}-merges.txt'.format(args.model)),
                    os.path.join(args.save_dir, '{}.model'.format(args.model)))
                hf_vocab_file = model_prefix + '-vocab.json'
                with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                    vocab_kv = json.load(fv)
                    vocab_kv = sorted(list(vocab_kv.items()),
                                      key=lambda x: x[1])
                    for kv in vocab_kv:
                        vocab.append(kv[0])
            os.remove(hf_vocab_file)
    else:
        raise NotImplementedError
    vocab_obj = Vocab(vocab, **special_tokens_kv)
    vocab_obj.save(model_prefix + '.vocab')
    print('-------- Done Training -------------')
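

# Illustrative sketch only (not part of the original script): one way main()
# above could be driven directly, e.g. from a test or a notebook. The actual
# command-line parser is presumably defined elsewhere in this script and may
# differ; the attribute names below simply mirror the ones main() reads, and
# the corpus file name and hyperparameter values are made-up placeholders.
def _example_invocation():
    from argparse import Namespace
    example_args = Namespace(
        corpus=['train.txt'],         # hypothetical corpus file(s)
        model='spm',                  # spm / subword_nmt / yttm / hf_bpe / hf_bytebpe / hf_wordpiece
        save_dir=None,                # None -> defaults to args.model inside main()
        vocab_size=8000,
        coverage=0.9995,              # character coverage used by spm / yttm
        input_sentence_size=1000000,
        n_threads=4,
        lowercase=False,
        bert_normalizer=True,
        split_punctuation=False,
        disable_unk=False,
        disable_bos=False,
        disable_eos=False,
        disable_pad=False,
        custom_special_tokens=[])
    main(example_args)
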
def main(args):
    corpus_path_list = args.corpus
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    model_prefix = os.path.join(args.save_dir, args.model)
    special_tokens_kv = OrderedDict()
    # unk is always required
    special_tokens_kv['unk_token'] = Vocab.UNK_TOKEN
    if not args.disable_bos:
        special_tokens_kv['bos_token'] = Vocab.BOS_TOKEN
    if not args.disable_eos:
        special_tokens_kv['eos_token'] = Vocab.EOS_TOKEN
    if not args.disable_pad:
        special_tokens_kv['pad_token'] = Vocab.PAD_TOKEN
    # split custom special tokens
    if args.model in ['yttm'] and len(args.custom_special_tokens) > 0:
        raise ValueError(
            'model {} does not support custom_special_tokens'.format(args.model))
    for custom_special_token in args.custom_special_tokens:
        kv = custom_special_token.split('=')
        if not len(kv) == 2:
            raise ValueError(
                'parameter {} has the wrong format'.format(custom_special_token))
        k, v = kv[0], kv[1]
        if k in special_tokens_kv:
            raise ValueError(
                'There are overlaps between the custom special tokens and the'
                ' unk, bos, eos, pad tokens')
        special_tokens_kv[k] = v
    # hf_wordpiece must contain mask, cls and sep tokens
    # the custom defined mask, cls, sep can overwrite the default settings
    if args.model == 'hf_wordpiece':
        if 'mask_token' not in special_tokens_kv:
            special_tokens_kv['mask_token'] = Vocab.MASK_TOKEN
        if 'cls_token' not in special_tokens_kv:
            special_tokens_kv['cls_token'] = Vocab.CLS_TOKEN
        if 'sep_token' not in special_tokens_kv:
            special_tokens_kv['sep_token'] = Vocab.SEP_TOKEN
    special_tokens = list(special_tokens_kv.values())
    print('special tokens: ' + ', '.join(special_tokens))
    vocab = []
    if args.model == 'spm':
        try_import_sentencepiece()
        import sentencepiece as spm
        corpus_path = ','.join(corpus_path_list)
        script = '--input={} --model_prefix={} --vocab_size={} --character_coverage={} --input_sentence_size={}' \
            .format(corpus_path, model_prefix, args.vocab_size, args.coverage, args.input_sentence_size)
        script += (' --unk_id=' + str(special_tokens.index(Vocab.UNK_TOKEN)))
        script += (' --bos_id=' + ('-1' if args.disable_bos else str(
            special_tokens.index(Vocab.BOS_TOKEN))))
        script += (' --eos_id=' + ('-1' if args.disable_eos else str(
            special_tokens.index(Vocab.EOS_TOKEN))))
        script += (' --pad_id=' + ('-1' if args.disable_pad else str(
            special_tokens.index(Vocab.PAD_TOKEN))))
        if len(args.custom_special_tokens) > 0:
            ids_in_script = script.count('_id')
            script += (' --control_symbols=' +
                       ','.join(special_tokens[ids_in_script:]))
        print(script)
        spm.SentencePieceTrainer.Train(script)
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<s>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '</s>'
        # build spm vocab
        spm_model = spm.SentencePieceProcessor()
        spm_model.load(model_prefix + '.model')
        vocab = [spm_model.id_to_piece(i) for i in range(len(spm_model))]
        os.remove(model_prefix + '.vocab')
    elif args.model == 'subword_nmt':
        try_import_subword_nmt()
        from subword_nmt import learn_bpe
        corpus_path = cat_corpus(corpus_path_list)\
            if len(corpus_path_list) > 1 else corpus_path_list[0]
        # build model
        with open(corpus_path, 'r', encoding='utf-8') as fc,\
                open(model_prefix + '.model', 'w', encoding='utf-8') as fm:
            learn_bpe.learn_bpe(fc, fm, args.vocab_size - len(special_tokens),
                                total_symbols=True)
        # build vocab
        with open(corpus_path, 'r', encoding='utf-8') as fc, \
                open(model_prefix + '.model', 'r', encoding='utf-8') as fm:
            vocab.extend(special_tokens)
            uniq_chars_internal = set()
            uniq_chars_final = set()
            uniq_words = set()
            for line in fc:
                for word in line.strip('\r\n ').split(' '):
                    if word:
                        uniq_words.add(word)
            # this code piece is the same as
            # https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/learn_bpe.py
            uniq_words = [
                tuple(x[:-1]) + (x[-1] + '</w>', ) for x in uniq_words
            ]
            for word in uniq_words:
                for char in word[:-1]:
                    uniq_chars_internal.add(char)
                uniq_chars_final.add(word[-1])
            # sort to ensure the same settings produce the same vocab
            vocab.extend(sorted(list(uniq_chars_internal)))
            vocab.extend(sorted(list(uniq_chars_final)))
            fm.readline()
            pair = fm.readline()
            while pair:
                vocab.append(pair.replace(' ', '', 1).strip())
                pair = fm.readline()
        if len(corpus_path_list) > 1:
            os.remove(corpus_path)
    elif args.model == 'yttm':
        try_import_yttm()
        import youtokentome as yttm
        corpus_path = cat_corpus(corpus_path_list)\
            if len(corpus_path_list) > 1 else corpus_path_list[0]
        tokenizer = yttm.BPE.train(
            data=corpus_path,
            model=model_prefix + '.model',
            vocab_size=args.vocab_size,
            coverage=args.coverage,
            n_threads=args.n_threads,
            unk_id=special_tokens.index(Vocab.UNK_TOKEN),
            bos_id=-1 if args.disable_bos else special_tokens.index(Vocab.BOS_TOKEN),
            eos_id=-1 if args.disable_eos else special_tokens.index(Vocab.EOS_TOKEN),
            pad_id=-1 if args.disable_pad else special_tokens.index(Vocab.PAD_TOKEN))
        vocab = tokenizer.vocab()
        if 'unk_token' in special_tokens_kv:
            special_tokens_kv['unk_token'] = '<UNK>'
        if 'bos_token' in special_tokens_kv:
            special_tokens_kv['bos_token'] = '<BOS>'
        if 'eos_token' in special_tokens_kv:
            special_tokens_kv['eos_token'] = '<EOS>'
        if 'pad_token' in special_tokens_kv:
            special_tokens_kv['pad_token'] = '<PAD>'
        if len(corpus_path_list) > 1:
            os.remove(corpus_path)
    elif args.model in ['hf_bpe', 'hf_bytebpe', 'hf_wordpiece']:
        tokenizers = try_import_huggingface_tokenizers()
        if args.model == 'hf_bpe':
            tokenizer = tokenizers.CharBPETokenizer(lowercase=args.lowercase)
        elif args.model == 'hf_bytebpe':
            tokenizer = tokenizers.ByteLevelBPETokenizer(
                lowercase=args.lowercase)
        elif args.model == 'hf_wordpiece':
            tokenizer = tokenizers.BertWordPieceTokenizer(
                lowercase=args.lowercase, strip_accents=args.strip_accents)
        else:
            raise NotImplementedError
        tokenizer.train(corpus_path_list,
                        vocab_size=args.vocab_size,
                        show_progress=True,
                        special_tokens=special_tokens)
        tokenizer.save(args.save_dir, args.model)
        # we replace the huggingface vocab file with our Vocab implementation
        if args.model == 'hf_wordpiece':
            hf_vocab_file = model_prefix + '-vocab.txt'
            with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                for line in fv:
                    vocab.append(line.strip())
        else:
            # Move the hf_${model}-merges.txt to hf_${model}.model
            os.rename(
                os.path.join(args.save_dir,
                             '{}-merges.txt'.format(args.model)),
                os.path.join(args.save_dir, '{}.model'.format(args.model)))
            hf_vocab_file = model_prefix + '-vocab.json'
            with open(hf_vocab_file, 'r', encoding='utf-8') as fv:
                vocab_kv = json.load(fv)
                vocab_kv = sorted(list(vocab_kv.items()), key=lambda x: x[1])
                for kv in vocab_kv:
                    vocab.append(kv[0])
        os.remove(hf_vocab_file)
    else:
        raise NotImplementedError
    unk_token = special_tokens_kv.pop('unk_token')
    vocab_obj = Vocab(vocab, unk_token=unk_token, **special_tokens_kv)
    vocab_obj.save(model_prefix + '.vocab')