def __init__(self, model, cfg, vocabulary: Vocabulary):
    """Create the word, tag and GloVe lookup tables.

    Args:
        model: Parent parameter collection. A sub-collection is created
            from it with ``add_subcollection()`` and owns all three tables.
        cfg: Configuration object. ``WORD_DIM`` and ``TAG_DIM`` give the
            embedding widths; ``GLOVE`` is handed to ``glove_reader``.
        vocabulary: Provides the row counts of the ``'word'`` and ``'tag'``
            tables through ``get_vocab_size``.
    """
    pc = model.add_subcollection()

    # Word embeddings start as all zeros; tag embeddings are drawn from a
    # standard normal distribution.
    word_num = vocabulary.get_vocab_size('word')
    self.wlookup = pc.lookup_parameters_from_numpy(
        np.zeros((word_num, cfg.WORD_DIM), dtype=np.float32))
    tag_num = vocabulary.get_vocab_size('tag')
    self.tlookup = pc.lookup_parameters_from_numpy(
        np.random.randn(tag_num, cfg.TAG_DIM).astype(np.float32))

    _, glove_vec = glove_reader(cfg.GLOVE)
    glove_dim = len(glove_vec[0])
    # Two all-zero rows are placed in front of the pretrained vectors
    # (presumably the UNK and PAD slots -- confirm against the vocabulary).
    zero_row = [0.0] * glove_dim
    padded_rows = [zero_row, zero_row] + glove_vec

    # Parse the nested list only once. The standard deviation is taken
    # over the whole padded table (zero rows included) at NumPy's default
    # dtype, and the float32 copy of the table is divided by it.
    full_table = np.array(padded_rows)
    scale = np.std(full_table)
    scaled_table = full_table.astype(np.float32) / scale
    # The division may promote the result, so cast back to float32.
    self.glookup = pc.lookup_parameters_from_numpy(
        scaled_table.astype(np.float32))

    # Note: token_dim counts only the word and tag widths.
    self.token_dim = cfg.WORD_DIM + cfg.TAG_DIM
    self.vocabulary = vocabulary
    self.pc, self.cfg = pc, cfg
    self.spec = (cfg, vocabulary)
def test_vocabulary(self):
    """Check token/index mapping for pretrained and counter namespaces.

    Builds one ``Vocabulary`` mixing pretrained token lists and token
    counters, with some namespaces configured without the unknown entry
    and some without both the unknown and padding entries. The assertions
    then pin the expected index of known tokens, the fallback index of
    unknown tokens, and the ``RuntimeError`` expected where no fallback
    exists.
    """
    pretrained_vocabs = {
        'glove': ['a', 'b', 'c'],
        'w2v': ['b', 'c', 'd'],
        'glove_nounk': ['a', 'b', 'c'],
        'glove_nounk_nopad': ['a', 'b', 'c'],
    }
    counters = {
        'w': Counter(["This", "is", "a", "test", "sentence", '.']),
        'w_m': Counter(['This', 'is', 'is']),
        'w_nounk': Counter(['This', 'is']),
        'w_nounk_nopad': Counter(['This', 'is', 'a']),
    }
    without_pad = {'glove_nounk_nopad', 'w_nounk_nopad'}
    without_unk = {
        'glove_nounk', 'w_nounk', 'glove_nounk_nopad', 'w_nounk_nopad'
    }

    vocab = Vocabulary(
        counters=counters,
        min_count={'w_m': 2},
        pretrained_vocab=pretrained_vocabs,
        intersection_vocab={'w2v': 'glove'},
        no_pad_namespace=without_pad,
        no_unk_namespace=without_unk)

    def expect_index(token, namespace, index):
        # A token must resolve to exactly this index in the namespace.
        assert vocab.get_token_index(token, namespace) == index

    def expect_runtime_error(lookup, *args):
        # The lookup must fail, and with RuntimeError itself rather than
        # a subclass of it.
        with pytest.raises(RuntimeError) as excinfo:
            lookup(*args)
        assert excinfo.type == RuntimeError

    # Test glove
    print(vocab.get_vocab_size('glove'))
    expect_index('a', 'glove', 2)
    expect_index('c', 'glove', 4)
    expect_index('d', 'glove', 0)

    # Test w2v
    expect_index('b', 'w2v', 2)
    expect_index('d', 'w2v', 0)
    assert vocab.get_token_from_index(2, 'w2v') == 'b'
    expect_runtime_error(vocab.get_token_from_index, 4, 'w2v')

    # Test glove_nounk
    expect_index('a', 'glove_nounk', 1)
    expect_index('c', 'glove_nounk', 3)
    expect_runtime_error(vocab.get_token_index, 'd', 'glove_nounk')

    # Test glove_nounk_nopad
    expect_index('a', 'glove_nounk_nopad', 0)
    expect_index('c', 'glove_nounk_nopad', 2)
    expect_runtime_error(vocab.get_token_index, 'd', 'glove_nounk_nopad')

    # Test w
    expect_index('a', 'w', 4)
    expect_index('.', 'w', 7)
    expect_index('That', 'w', 0)

    # Test w_m
    expect_index('is', 'w_m', 2)
    expect_index('This', 'w_m', 0)
    expect_index('That', 'w_m', 0)

    # Test w_nounk
    expect_runtime_error(vocab.get_token_index, 'That', 'w_nounk')
    expect_index('This', 'w_nounk', 1)

    # Test w_nounk_nopad
    expect_runtime_error(vocab.get_token_index, 'That', 'w_nounk_nopad')
    expect_index('This', 'w_nounk_nopad', 0)