Example #1
import torch

# Vocab, load_word2vec_as_vocab_tensor and embedding_uniform are assumed to be
# provided by the surrounding library (its vocabulary and embedding utilities).
def index_word2vec_with_vocab(filepath: str,
                              vocab: Vocab,
                              extend_vocab=True,
                              unk=None,
                              lowercase=False,
                              init='uniform',
                              normalize=None) -> torch.Tensor:
    """

    Args:
        filepath: The path to pretrained embedding.
        vocab: The vocabulary from training set.
        extend_vocab: Unlock vocabulary of training set to add those tokens in pretrained embedding file.
        unk: UNK token.
        lowercase: Convert words in pretrained embeddings into lowercase.
        init: Indicate which initialization to use for oov tokens.
        normalize: ``True`` or a method to normalize the embedding matrix.

    Returns:
        An embedding matrix.

    """
    pret_vocab, pret_matrix = load_word2vec_as_vocab_tensor(filepath)
    if unk and unk in pret_vocab:
        pret_vocab[vocab.safe_unk_token] = pret_vocab.pop(unk)
    if extend_vocab:
        vocab.unlock()
        for word in pret_vocab:
            vocab.get_idx(word.lower() if lowercase else word)
    vocab.lock()
    # Map every vocabulary token to a row of the pretrained matrix; tokens
    # missing from the pretrained file are assigned fresh ids past its end.
    ids = []
    unk_id_offset = 0
    for word, idx in vocab.token_to_idx.items():
        word_id = pret_vocab.get(word, None)
        # Retry with the lowercased form
        if word_id is None:
            word_id = pret_vocab.get(word.lower(), None)
        if word_id is None:
            word_id = len(pret_vocab) + unk_id_offset
            unk_id_offset += 1
        ids.append(word_id)
    if unk_id_offset:
        # Append rows for the OOV tokens, zero-initialized by default
        unk_embeds = torch.zeros(unk_id_offset, pret_matrix.size(1))
        if init and init != 'zeros':
            if init == 'uniform':
                init = embedding_uniform
            else:
                raise ValueError(f'Unsupported init {init}')
            unk_embeds = init(unk_embeds)
        pret_matrix = torch.cat([pret_matrix, unk_embeds])
    ids = torch.LongTensor(ids)
    embedding = pret_matrix.index_select(0, ids)
    if normalize == 'norm':
        # L2-normalize each row; the epsilon guards against all-zero rows
        embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)
    elif normalize == 'std':
        # Scale the whole matrix to unit standard deviation
        embedding /= torch.std(embedding)
    return embedding
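
For context, a minimal usage sketch under the assumptions above (a Vocab with an
add method, as in Example #2, and a word2vec-format text file); the path and the
tokens are placeholders:

vocab = Vocab()
for token in ['the', 'cat', 'sat']:
    vocab.add(token)
embedding = index_word2vec_with_vocab('data/word2vec.txt', vocab,
                                      extend_vocab=True,
                                      init='uniform',
                                      normalize='norm')
# One row per vocabulary entry, aligned with the indices in vocab
print(embedding.shape)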
Example #2
from typing import Tuple

# Vocab is assumed to be the same vocabulary class used in Example #1.
def vocab_from_tsv(tsv_file_path, lower=False, lock_word_vocab=False, lock_char_vocab=True, lock_tag_vocab=True) \
        -> Tuple[Vocab, Vocab, Vocab]:
    word_vocab = Vocab()
    char_vocab = Vocab()
    tag_vocab = Vocab(unk_token=None)  # no UNK token for tags
    with open(tsv_file_path, encoding='utf-8') as tsv_file:
        for line in tsv_file:
            cells = line.strip().split()  # split on any whitespace, including tabs
            if cells:  # skip blank (sentence-separator) lines
                word, tag = cells
                if lower:
                    word_vocab.add(word.lower())
                else:
                    word_vocab.add(word)
                char_vocab.update(list(word))  # characters come from the original-case word
                tag_vocab.add(tag)
    if lock_word_vocab:
        word_vocab.lock()
    if lock_char_vocab:
        char_vocab.lock()
    if lock_tag_vocab:
        tag_vocab.lock()
    return word_vocab, char_vocab, tag_vocab
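
A minimal usage sketch for this helper; the tiny two-column file written below is
made up for illustration, and token_to_idx is the token-to-index mapping already
seen in Example #1:

with open('sample.tsv', 'w', encoding='utf-8') as f:
    f.write('I\tPRON\nlike\tVERB\ntea\tNOUN\n')
word_vocab, char_vocab, tag_vocab = vocab_from_tsv('sample.tsv', lower=True)
print(word_vocab.token_to_idx)  # lowercased words mapped to indices
print(tag_vocab.token_to_idx)   # tags collected from the second column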