import io
import os
import warnings

import gluonnlp as nlp
# _load_pretrained_vocab is a private helper shipped with gluonnlp 0.x.
from gluonnlp.data.utils import _load_pretrained_vocab


def _load_vocab(dataset_name, vocab, root):
    if dataset_name:
        if vocab is not None:
            warnings.warn('Both dataset_name and vocab are specified. '
                          'Loading vocab for dataset. '
                          'Input "vocab" argument will be ignored.')
        vocab = _load_pretrained_vocab(dataset_name, root)
    else:
        assert vocab is not None, 'Must specify vocab if not loading from predefined datasets.'
    return vocab
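# Example usage (illustrative sketch, not part of the conversion flow): either
# load a predefined vocabulary by name, or pass an existing gluonnlp Vocab
# through unchanged. 'openai_webtext' is the name used elsewhere in this
# script; 'my_vocab' is a hypothetical caller-supplied Vocab.
#
#   vocab = _load_vocab('openai_webtext', vocab=None, root='.')  # fetch by name
#   vocab = _load_vocab(None, vocab=my_vocab, root='.')          # use caller's vocab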
def fairseq_vocab_to_gluon_vocab(torch_vocab):
    index_to_words = [None] * len(torch_vocab)

    bos_idx = torch_vocab.bos()
    pad_idx = torch_vocab.pad()
    eos_idx = torch_vocab.eos()
    unk_idx = torch_vocab.unk()

    index_to_words[bos_idx] = torch_vocab.symbols[bos_idx]
    index_to_words[pad_idx] = torch_vocab.symbols[pad_idx]
    index_to_words[eos_idx] = torch_vocab.symbols[eos_idx]
    index_to_words[unk_idx] = torch_vocab.symbols[unk_idx]

    specials = [bos_idx, pad_idx, eos_idx, unk_idx]

    openai_to_roberta = {}
    openai_vocab = _load_pretrained_vocab('openai_webtext', '.')

    # 'ckpt_dir' must point at the fairseq checkpoint directory containing
    # dict.txt; it is expected to be defined at module scope (e.g. parsed
    # from command-line arguments).
    with io.open(os.path.join(ckpt_dir, 'dict.txt'), encoding='utf-8') as f:
        for i, line in enumerate(f):
            token, count = line.split(' ')
            try:
                # fairseq's RoBERTa dict.txt stores BPE ids as integer strings;
                # a non-integer token falls through to the except branch and is
                # kept verbatim.
                int(token)
                openai_to_roberta[token] = i + len(specials)
            except ValueError:
                index_to_words[i + len(specials)] = token

    # Map each OpenAI WebText token to its position in the fairseq dictionary.
    for idx, token in enumerate(openai_vocab.idx_to_token):
        if str(idx) in openai_to_roberta:
            index_to_words[openai_to_roberta[str(idx)]] = token
        else:
            assert token == u'<mask>', token

    mask_idx = torch_vocab.index(u'<mask>')
    index_to_words[mask_idx] = torch_vocab.string([mask_idx])
    assert None not in index_to_words

    word2idx = {token: idx for idx, token in enumerate(index_to_words)}

    vocab = nlp.vocab.Vocab(word2idx, token_to_idx=word2idx,
                            unknown_token=index_to_words[unk_idx],
                            padding_token=index_to_words[pad_idx],
                            bos_token=index_to_words[bos_idx],
                            eos_token=index_to_words[eos_idx],
                            mask_token=u'<mask>')
    return vocab
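# Example usage (sketch, assuming fairseq is installed): converting the source
# dictionary of a fairseq RoBERTa checkpoint. '/path/to/roberta.base' is a
# placeholder; it must contain model.pt and dict.txt, and 'ckpt_dir' above
# must point at the same directory.
#
#   from fairseq.models.roberta import RobertaModel
#   roberta = RobertaModel.from_pretrained('/path/to/roberta.base',
#                                          checkpoint_file='model.pt')
#   gluon_vocab = fairseq_vocab_to_gluon_vocab(roberta.task.source_dictionary)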