def __init__(self, model, cfg, vocabulary: Vocabulary):
        """Create the word, tag and pretrained (GloVe) lookup tables.

        Args:
            model: object exposing ``add_subcollection()``; every lookup
                table built here is registered on the returned collection
                (presumably a DyNet ParameterCollection -- confirm).
            cfg: configuration providing ``WORD_DIM``, ``TAG_DIM`` and
                ``GLOVE`` (the value handed to ``glove_reader``).
            vocabulary: supplies the sizes of the 'word', 'tag' and
                'glove' namespaces.
        """
        collection = model.add_subcollection()

        # Word table: all zeros, one row per entry of the 'word' namespace.
        n_words = vocabulary.get_vocab_size('word')
        word_init = np.zeros((n_words, cfg.WORD_DIM), dtype=np.float32)
        self.wlookup = collection.lookup_parameters_from_numpy(word_init)

        # Tag table: standard-normal samples, one row per 'tag' entry.
        n_tags = vocabulary.get_vocab_size('tag')
        tag_init = np.random.randn(n_tags, cfg.TAG_DIM).astype(np.float32)
        self.tlookup = collection.lookup_parameters_from_numpy(tag_init)

        # Pretrained table: glove_reader yields a pair, of which only the
        # second item (the vectors) is used here.
        _, pretrained_rows = glove_reader(cfg.GLOVE)
        row_width = len(pretrained_rows[0])
        zero_row = [0.0] * row_width
        # NOTE(review): the size of the 'glove' namespace is queried but its
        # value is not used anywhere below.
        _glove_size = vocabulary.get_vocab_size('glove')
        # Two all-zero rows go in front of the pretrained vectors
        # (presumably the unknown/padding slots -- confirm against the
        # vocabulary's index layout).
        padded_rows = [zero_row, zero_row] + pretrained_rows
        as_float32 = np.array(padded_rows, dtype=np.float32)
        # The standard deviation is taken over every entry of the padded
        # table, i.e. the two zero rows are included in the statistic.
        spread = np.std(padded_rows)
        scaled = as_float32 / spread
        self.glookup = collection.lookup_parameters_from_numpy(
            scaled.astype(np.float32))

        self.token_dim = cfg.WORD_DIM + cfg.TAG_DIM
        self.vocabulary = vocabulary
        self.pc = collection
        self.cfg = cfg
        self.spec = (cfg, vocabulary)
示例#2
0
    def test_vocabulary(self):
        """Check token/index lookups for several Vocabulary namespace setups.

        Covers pretrained and counter-built namespaces, each in three
        flavours: default, listed in ``no_unk_namespace`` only, and listed
        in both ``no_unk_namespace`` and ``no_pad_namespace``. Also covers
        a ``min_count`` threshold and an ``intersection_vocab`` mapping.
        """
        abc = ['a', 'b', 'c']
        pretrained = {
            'glove': list(abc),
            'w2v': ['b', 'c', 'd'],
            'glove_nounk': list(abc),
            'glove_nounk_nopad': list(abc),
        }

        token_counts = {
            'w': Counter(['This', 'is', 'a', 'test', 'sentence', '.']),
            'w_m': Counter(['This', 'is', 'is']),
            'w_nounk': Counter(['This', 'is']),
            'w_nounk_nopad': Counter(['This', 'is', 'a']),
        }

        without_pad = {'glove_nounk_nopad', 'w_nounk_nopad'}
        without_unk = {
            'glove_nounk', 'w_nounk', 'glove_nounk_nopad', 'w_nounk_nopad'
        }

        vocab = Vocabulary(
            counters=token_counts,
            min_count={'w_m': 2},
            pretrained_vocab=pretrained,
            intersection_vocab={'w2v': 'glove'},
            no_pad_namespace=without_pad,
            no_unk_namespace=without_unk)

        def expect_index(token, namespace, expected):
            # Single lookup followed by an equality check on the index.
            assert vocab.get_token_index(token, namespace) == expected

        def expect_lookup_error(token, namespace):
            # The lookup must raise, and the raised type must be exactly
            # RuntimeError (pytest.raises alone would accept a subclass).
            with pytest.raises(RuntimeError) as excinfo:
                vocab.get_token_index(token, namespace)
            assert excinfo.type == RuntimeError

        # glove: pretrained tokens are expected to start at index 2, and a
        # token outside the list is expected to map to index 0.
        print(vocab.get_vocab_size('glove'))
        expect_index('a', 'glove', 2)
        expect_index('c', 'glove', 4)
        expect_index('d', 'glove', 0)

        # w2v (declared with intersection_vocab against 'glove'): 'd' is in
        # the w2v list but not in the glove list and is expected at index 0;
        # index 4 is expected to be out of range.
        expect_index('b', 'w2v', 2)
        expect_index('d', 'w2v', 0)
        assert vocab.get_token_from_index(2, 'w2v') == 'b'
        with pytest.raises(RuntimeError) as excinfo:
            vocab.get_token_from_index(4, 'w2v')
        assert excinfo.type == RuntimeError

        # glove_nounk: tokens expected to start at index 1; a missing token
        # is expected to raise instead of returning an index.
        expect_index('a', 'glove_nounk', 1)
        expect_index('c', 'glove_nounk', 3)
        expect_lookup_error('d', 'glove_nounk')

        # glove_nounk_nopad: tokens expected to start at index 0.
        expect_index('a', 'glove_nounk_nopad', 0)
        expect_index('c', 'glove_nounk_nopad', 2)
        expect_lookup_error('d', 'glove_nounk_nopad')

        # w: namespace built from counters with default settings.
        expect_index('a', 'w', 4)
        expect_index('.', 'w', 7)
        expect_index('That', 'w', 0)

        # w_m: min_count is 2, so 'This' (counted once) is expected to
        # resolve to index 0 just like the never-seen 'That'.
        expect_index('is', 'w_m', 2)
        expect_index('This', 'w_m', 0)
        expect_index('That', 'w_m', 0)

        # w_nounk: a missing token raises; known tokens start at index 1.
        expect_lookup_error('That', 'w_nounk')
        expect_index('This', 'w_nounk', 1)

        # w_nounk_nopad: a missing token raises; known tokens start at 0.
        expect_lookup_error('That', 'w_nounk_nopad')
        expect_index('This', 'w_nounk_nopad', 0)