Example #1
0
def _load_vocab(dataset_name, vocab, root):
    if dataset_name:
        if vocab is not None:
            warnings.warn('Both dataset_name and vocab are specified. Loading vocab for dataset. '
                          'Input "vocab" argument will be ignored.')
        vocab = _load_pretrained_vocab(dataset_name, root)
    else:
        assert vocab is not None, 'Must specify vocab if not loading from predefined datasets.'
    return vocab
Example #2
0
def _load_vocab(dataset_name, vocab, root):
    if dataset_name:
        if vocab is not None:
            warnings.warn('Both dataset_name and vocab are specified. Loading vocab for dataset. '
                          'Input "vocab" argument will be ignored.')
        vocab = _load_pretrained_vocab(dataset_name, root)
    else:
        assert vocab is not None, 'Must specify vocab if not loading from predefined datasets.'
    return vocab
Example #3
0
def fairseq_vocab_to_gluon_vocab(torch_vocab):
    """Convert a fairseq (RoBERTa) dictionary into a GluonNLP ``Vocab``.

    Parameters
    ----------
    torch_vocab
        A fairseq dictionary exposing ``bos()``, ``pad()``, ``eos()``,
        ``unk()``, ``symbols``, ``index()`` and ``string()``.

    Returns
    -------
    nlp.vocab.Vocab
        A vocabulary whose indices match ``torch_vocab`` exactly, including
        the special tokens and ``<mask>``.
    """
    index_to_words = [None] * len(torch_vocab)

    bos_idx = torch_vocab.bos()
    pad_idx = torch_vocab.pad()
    eos_idx = torch_vocab.eos()
    unk_idx = torch_vocab.unk()

    # Special tokens keep the exact indices fairseq assigned them.
    index_to_words[bos_idx] = torch_vocab.symbols[bos_idx]
    index_to_words[pad_idx] = torch_vocab.symbols[pad_idx]
    index_to_words[eos_idx] = torch_vocab.symbols[eos_idx]
    index_to_words[unk_idx] = torch_vocab.symbols[unk_idx]

    specials = [bos_idx, pad_idx, eos_idx, unk_idx]

    openai_to_roberta = {}
    openai_vocab = _load_pretrained_vocab('openai_webtext', '.')

    # NOTE(review): `ckpt_dir` is not defined in this function — presumably a
    # module-level global pointing at the fairseq checkpoint dir; confirm.
    with io.open(os.path.join(ckpt_dir, 'dict.txt'), encoding='utf-8') as f:
        for i, line in enumerate(f):
            token, _ = line.split(' ')
            try:
                # Numeric tokens in dict.txt are indices into the OpenAI
                # webtext vocab; record where fairseq placed them. The
                # int() call exists only to route non-numeric tokens to
                # the except branch (previously bound to an unused local).
                int(token)
                openai_to_roberta[token] = i + len(specials)
            except ValueError:
                # Non-numeric entries are literal symbols; copy them through.
                index_to_words[i + len(specials)] = token

    # Fill the remapped slots with the actual OpenAI token strings.
    for idx, token in enumerate(openai_vocab.idx_to_token):
        if str(idx) in openai_to_roberta:
            index_to_words[openai_to_roberta[str(idx)]] = token
        else:
            # Only <mask> is expected to be absent from the fairseq dict.
            assert token == u'<mask>', token

    mask_idx = torch_vocab.index(u'<mask>')
    index_to_words[mask_idx] = torch_vocab.string([mask_idx])
    # Sanity check: every index must have been assigned a token.
    assert None not in index_to_words
    word2idx = {token: idx for idx, token in enumerate(index_to_words)}

    # NOTE(review): word2idx is passed both as the positional counter and as
    # token_to_idx — this pins the index assignment; verify against the
    # nlp.vocab.Vocab API before changing.
    vocab = nlp.vocab.Vocab(word2idx,
                            token_to_idx=word2idx,
                            unknown_token=index_to_words[unk_idx],
                            padding_token=index_to_words[pad_idx],
                            bos_token=index_to_words[bos_idx],
                            eos_token=index_to_words[eos_idx],
                            mask_token=u'<mask>')
    return vocab