Пример #1
0
    def test_seq_length(self):
        """Check right-aligned batching when a fixed ``seq_length`` is given.

        With ``seq_length=4`` the five-token sequence is cut down to its last
        four tokens. Shorter sequences are left-padded with index 0, and their
        mask is 0 at the padded positions. Finally, the model's ``values`` and
        ``mask`` tensors, passed through ``tf.identity``, must report a static
        shape of ``[None, 4]``.
        """
        seq_len = 4
        vocab = VocabExample('<unk> a b c'.split(), '<unk>')
        # The expected ids below encode '<unk>'=0, 'a'=1, 'b'=2, 'c'=3.
        batch = [
            'a b a b c'.split(),  # longer than seq_len: only the tail is kept
            'a b'.split(),
            ['b'],
            ['c'],
        ]

        expected_values = np.array(
            [
                [2, 1, 2, 3],  # "b a b c": the last four tokens of row 0
                [0, 0, 1, 2],
                [0, 0, 0, 2],
                [0, 0, 0, 3],
            ],
            dtype=np.int32,
        )
        expected_mask = np.array(
            [
                [1, 1, 1, 1],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 1],
            ],
            dtype=np.float32,
        )

        with clean_session():
            feeder = FeedSequenceBatch(align='right', seq_length=seq_len)
            feed = feeder.inputs_to_feed_dict(batch, vocab)
            assert_array_collections_equal(
                {feeder.values: expected_values, feeder.mask: expected_mask},
                feed,
            )

            # Each tensor, wrapped in tf.identity, must keep an unknown batch
            # dimension and a time dimension fixed at seq_len.
            for tensor in (feeder.values, feeder.mask):
                static_shape = tf.identity(tensor).get_shape().as_list()
                assert static_shape == [None, seq_len]
Пример #2
0
    def test_right_align(self, inputs):
        """Check right alignment when no ``seq_length`` is given.

        ``inputs`` is unpacked as an ``(args, kwargs)`` pair and forwarded to
        ``inputs_to_feed_dict``. It is presumably a fixture defined elsewhere
        in this test module (TODO confirm).

        The expected batch is five columns wide. Shorter rows are left-padded
        with index 0, and their mask is 0 at the padded positions.
        """
        expected_values = np.array(
            [
                [1, 1, 2, 2, 3],
                [0, 0, 0, 1, 2],
                [0, 0, 0, 0, 2],
                [0, 0, 0, 0, 3],
            ],
            dtype=np.int32,
        )
        expected_mask = np.array(
            [
                [1, 1, 1, 1, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 0, 0, 1],
                [0, 0, 0, 0, 1],
            ],
            dtype=np.float32,
        )

        with clean_session():
            feeder = FeedSequenceBatch(align='right')
            expected_feed = {
                feeder.values: expected_values,
                feeder.mask: expected_mask,
            }

            feed_args, feed_kwargs = inputs
            actual_feed = feeder.inputs_to_feed_dict(*feed_args, **feed_kwargs)
            assert_array_collections_equal(expected_feed, actual_feed)