def index_word2vec_with_vocab(filepath: str,
                              vocab: Vocab,
                              extend_vocab=True,
                              unk=None,
                              lowercase=False,
                              init='uniform',
                              normalize=None) -> torch.Tensor:
    """Build an embedding matrix whose rows follow the token ids of ``vocab``.

    Pretrained vectors are read with ``load_word2vec_as_vocab_tensor``. A token of ``vocab``
    that is found in the pretrained vocabulary reuses its pretrained vector; every other
    token gets a new vector, appended after the pretrained rows and initialized per ``init``.

    Args:
        filepath: The path to pretrained embedding.
        vocab: The vocabulary from training set. It is always locked when this returns.
        extend_vocab: Unlock vocabulary of training set to add those tokens in pretrained
            embedding file.
        unk: UNK token of the pretrained file. If present there, its vector is re-keyed to
            ``vocab.safe_unk_token``.
        lowercase: Convert words in pretrained embeddings into lowercase, both when adding
            them to ``vocab`` and when matching vocab tokens against them.
        init: Initialization for tokens without a pretrained vector: ``'uniform'`` (uses
            ``embedding_uniform``), ``'zeros'`` or a falsy value (zero vectors), or a callable
            that receives the zero tensor of those rows and returns the initialized tensor.
        normalize: ``'norm'`` divides each row by its L2 norm; ``'std'`` divides the whole
            matrix by its standard deviation; any other value leaves it unchanged.

    Returns:
        An embedding matrix with one row per entry of ``vocab.token_to_idx``, rows ordered by
        token id.

    Raises:
        ValueError: If ``init`` is neither ``'uniform'``, ``'zeros'``, falsy, nor callable.
    """
    pret_vocab, pret_matrix = load_word2vec_as_vocab_tensor(filepath)
    if unk and unk in pret_vocab:
        pret_vocab[vocab.safe_unk_token] = pret_vocab.pop(unk)
    if extend_vocab:
        vocab.unlock()
        for word in pret_vocab:
            vocab.get_idx(word.lower() if lowercase else word)
    vocab.lock()

    # With ``lowercase`` the vocab may hold 'paris' while the pretrained file only has
    # 'Paris', so also index pretrained ids by their lowercased token. ``setdefault`` keeps
    # the id of the first casing encountered.
    lowered_pret_vocab = {}
    if lowercase:
        for word, word_id in pret_vocab.items():
            lowered_pret_vocab.setdefault(word.lower(), word_id)

    # Rows appended for OOV tokens start right after the pretrained matrix. This is not
    # necessarily len(pret_vocab): re-keying ``unk`` onto an existing key shrinks the dict.
    oov_base = pret_matrix.size(0)
    ids = []
    num_oov = 0
    # Walk tokens in id order so the output rows line up with the vocab ids.
    for word, _ in sorted(vocab.token_to_idx.items(), key=lambda item: item[1]):
        word_id = pret_vocab.get(word)
        if word_id is None:
            # Retry lower case
            word_id = pret_vocab.get(word.lower())
        if word_id is None and lowercase:
            word_id = lowered_pret_vocab.get(word.lower())
        if word_id is None:
            word_id = oov_base + num_oov
            num_oov += 1
        ids.append(word_id)

    if num_oov:
        # Match dtype/device of the pretrained matrix so torch.cat cannot fail on a mismatch.
        oov_embeds = pret_matrix.new_zeros(num_oov, pret_matrix.size(1))
        if init and init != 'zeros':
            if init == 'uniform':
                init = embedding_uniform
            elif not callable(init):
                raise ValueError(f'Unsupported init {init}')
            oov_embeds = init(oov_embeds)
        pret_matrix = torch.cat([pret_matrix, oov_embeds])

    index = torch.tensor(ids, dtype=torch.long, device=pret_matrix.device)
    embedding = pret_matrix.index_select(0, index)
    if normalize == 'norm':
        # The epsilon guards against division by zero for all-zero rows.
        embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)
    elif normalize == 'std':
        embedding /= torch.std(embedding)
    return embedding
def vocab_from_tsv(tsv_file_path, lower=False, lock_word_vocab=False, lock_char_vocab=True,
                   lock_tag_vocab=True) -> Tuple[Vocab, Vocab, Vocab]:
    """Build word, character and tag vocabularies from a two-column tagged file.

    Each non-blank line must hold exactly two whitespace-separated cells: a word and its
    tag. Blank lines are skipped.

    Args:
        tsv_file_path: Path to the UTF-8 encoded input file.
        lower: Add lowercased words to the word vocabulary. Characters are always collected
            from the word as written in the file.
        lock_word_vocab: Lock the word vocabulary before returning it.
        lock_char_vocab: Lock the character vocabulary before returning it.
        lock_tag_vocab: Lock the tag vocabulary before returning it.

    Returns:
        A ``(word_vocab, char_vocab, tag_vocab)`` tuple; ``tag_vocab`` is created with
        ``unk_token=None``.

    Raises:
        ValueError: If a non-blank line does not contain exactly two cells.
    """
    word_vocab = Vocab()
    char_vocab = Vocab()
    tag_vocab = Vocab(unk_token=None)
    with open(tsv_file_path, encoding='utf-8') as tsv_file:
        for line_no, line in enumerate(tsv_file, start=1):
            # str.split() with no argument already strips surrounding whitespace.
            cells = line.split()
            if not cells:
                continue
            if len(cells) != 2:
                raise ValueError(
                    f'Expected 2 cells (word, tag) but found {len(cells)} '
                    f'at line {line_no} of {tsv_file_path}: {line!r}')
            word, tag = cells
            word_vocab.add(word.lower() if lower else word)
            char_vocab.update(list(word))
            tag_vocab.add(tag)
    if lock_word_vocab:
        word_vocab.lock()
    if lock_char_vocab:
        char_vocab.lock()
    if lock_tag_vocab:
        tag_vocab.lock()
    return word_vocab, char_vocab, tag_vocab