Example #1
    def forward(self, sentence):
        def embedding(word):
            ## look up one word per step; no recurrent state is kept
            return [self.embedding_table(word.squeeze(-1))], []

        ## map the embedding step over every word of every sentence
        embeddings, = rc.recurrent_group(seq_inputs=[sentence],
                                         step_func=embedding)
        ## sum the word embeddings of each sentence into a bag-of-words vector
        bow = sequence_pooling(embeddings, pooling_type="sum")
        if self.one_more_hidden:
            return self.hidden(bow)
        return bow
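
A minimal sketch of how this forward pass might be driven, assuming the rc.make_hierarchy_of_tensors helper used in the tests below also accepts an integer dtype for word ids; the model instance and the word ids here are hypothetical:

## hypothetical ragged batch: two sentences of word ids (assumption: the
## helper accepts "int64" for embedding lookups, as it does "float32" below)
sentences = [
    [[1], [4], [2]],  ## sentence 1: three word ids
    [[3], [0]],  ## sentence 2: two word ids
]
sentence_tensors = rc.make_hierarchy_of_tensors(sentences, "int64", "cpu", [1])
bow = model.forward(sentence_tensors)  ## one pooled vector per sentence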
Example #2
    def test_multi_sequential_inputs(self):
        ## we first make a batch of paragraphs of sentences
        ## Let's suppose that each word has a 1d embedding
        sentences = [
            ## paragraph 1
            [
                [[0.3], [0.4], [0.5]],  ## sentence 1
                [[0.1], [0.2]],  ## sentence 2
            ],
            ## paragraph 2
            [
                [[0.3], [0.4], [0.5]],  ## sentence 3
                [[0.2], [0.2]],  ## sentence 4
                [[1.0], [0.2], [0.4], [0.5]],  ## sentence 5
            ],
        ]
        sentence_tensors = rc.make_hierarchy_of_tensors(
            sentences, "float32", "cpu", [1])

        ## we then make a batch of paragraphs of images
        images = [
            ## paragraph 1
            [
                [0.1, 0.3],  ## image 1
                [0, -1],  ## image 2
            ],
            ## paragraph 2
            [
                [0, 1],  ## image 3
                [1, 0],  ## image 4
                [1, 1],  ## image 5
            ],
        ]
        image_tensors = rc.make_hierarchy_of_tensors(images, "float32", "cpu",
                                                     [2])

        states = [
            [-2, -4, -6, -8],  ## paragraph 1
            [-1, -2, -3, -4],  ## paragraph 2
        ]
        state_tensors = rc.make_hierarchy_of_tensors(states, "float32", "cpu",
                                                     [4])

        def step_func(sentence, image, state):
            """
            We compute the first output by doing outer product between the
            average word embedding and the image embedding.
            We compute the second output by adding the average word embedding
            and the image embedding.
            We directly add the state mean value to the output
            and update the state by multiplying it with -1.
            """
            assert isinstance(sentence, list)
            assert isinstance(image, torch.Tensor)
            assert isinstance(state, torch.Tensor)
            sentence = torch.stack([sen.mean(0) for sen in sentence])
            assert sentence.size()[0] == image.size()[0]

            mean_state = state.mean(-1).unsqueeze(-1)
            out1 = torch.bmm(sentence.unsqueeze(2), image.unsqueeze(1)).view(
                sentence.size()[0], -1) + mean_state
            out2 = sentence + image.mean(-1).unsqueeze(-1) + mean_state
            return [out1, out2], [state * -1]

        outs = rc.recurrent_group(seq_inputs=[sentence_tensors, image_tensors],
                                  insts=[],
                                  init_states=[state_tensors],
                                  step_func=step_func,
                                  out_states=True)

        self.assertTrue(
            tensor_lists_equal(
                outs,
                [
                    [
                        torch.tensor([[-4.9600, -4.8800], [5.0000, 4.8500]]),
                        torch.tensor([[-2.5000, -2.1000], [2.7000, 2.5000],
                                      [-1.9750, -1.9750]])
                    ],  ## out1
                    [
                        torch.tensor([[-4.4000], [4.6500]]),
                        torch.tensor([[-1.6000], [3.2000], [-0.9750]])
                    ],  ## out2
                    [
                        torch.tensor([[2., 4., 6., 8.], [-2., -4., -6., -8.]]),
                        torch.tensor([[1., 2., 3., 4.], [-1., -2., -3., -4.],
                                      [1., 2., 3., 4.]])
                    ]  ## state
                ]))
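
The expected values above can be checked by hand; for instance, the first rows of out1 and out2 for paragraph 1 (sentence 1 paired with image 1) follow from plain arithmetic:

## hand check of outer step 1, paragraph 1: average word embedding 0.4,
## image [0.1, 0.3], initial state [-2, -4, -6, -8]
sentence_mean = (0.3 + 0.4 + 0.5) / 3  ## 0.4
mean_state = (-2 - 4 - 6 - 8) / 4  ## -5.0
out1 = [sentence_mean * 0.1 + mean_state,  ## -4.96
        sentence_mean * 0.3 + mean_state]  ## -4.88
out2 = sentence_mean + (0.1 + 0.3) / 2 + mean_state  ## -4.4
## these match the first rows of the expected out1 and out2 tensors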
Example #3
    def test_hierchical_sequences(self):
        sentences = [
            ## paragraph 1
            [
                [[0.3], [0.4], [0.5]],  ## sentence 1
                [[0.1], [0.2]],  ## sentence 2
            ],
            ## paragraph 2
            [
                [[0.3], [0.4], [0.5]],  ## sentence 3
                [[0.2], [0.2]],  ## sentence 4
                [[1.0], [0.2], [0.4], [0.5]],  ## sentence 5
            ],
        ]
        imgs = [
            [2.0, 2.0, 2.0],  ## image 1
            [1.0, 1.0, 1.0]  ## image 2
        ]

        sentence_tensors = rc.make_hierarchy_of_tensors(
            sentences, "float32", "cpu", [1])
        img_tensors = rc.make_hierarchy_of_tensors(imgs, "float32", "cpu", [3])

        sentence_states = [
            [-2, -4, -6, -8],  ## paragraph 1
            [-1, -2, -3, -4],  ## paragraph 2
        ]
        sentence_state_tensors = rc.make_hierarchy_of_tensors(
            sentence_states, "float32", "cpu", [4])

        word_states = [
            [1.0, 1.0],  ## sentence 1
            [-1.0, -1.0],  ## sentence 3
        ]
        word_state_tensors = rc.make_hierarchy_of_tensors(
            word_states, "float32", "cpu", [2])

        ## This hierarchical step function does the following:
        ## 1. For each word in each sentence, add the word state to the
        ##    word embedding; the word state stays the same throughout.
        ## 2. Take the last word output and the last word state of each
        ##    sentence.
        ## 3. At the higher (sentence) level, multiply the last word output
        ##    by the sentence state and add the mean of the static image
        ##    input. Update the sentence state by multiplying it by -1.
        def step_func(sentence, img, sentence_state, word_state):
            assert isinstance(sentence, list)

            def inner_step_func(w, ws):
                ### w is the current word embedding
                ### ws is the current word state
                assert isinstance(w, torch.Tensor)
                assert isinstance(ws, torch.Tensor)
                ## return the output and the (unchanged) word state
                return [w + ws.mean(-1).unsqueeze(-1)], [ws]

            outputs, word_states = rc.recurrent_group(
                seq_inputs=[sentence],
                insts=[],
                init_states=[word_state],
                step_func=inner_step_func,
                out_states=True)

            last_outputs = torch.stack([o[-1] for o in outputs])
            last_word_states = torch.stack([s[-1] for s in word_states])
            ## we compute the output by multiplying the last word outputs
            ## by the sentence state and adding the mean image embedding
            out = last_outputs * sentence_state + img.mean(-1).unsqueeze(-1)
            return [out], [sentence_state * -1, last_word_states]

        outs, sentence_states, word_states \
            = rc.recurrent_group(seq_inputs=[sentence_tensors],
                                 insts=[img_tensors],
                                 init_states=[sentence_state_tensors,
                                              word_state_tensors],
                                 step_func=step_func,
                                 out_states=True)
        self.assertTrue(
            tensor_lists_equal(outs, [
                torch.tensor([[-1.0, -4.0, -7.0, -10.0], [4.4, 6.8, 9.2, 11.6]
                              ]),
                torch.tensor([[1.5, 2.0, 2.5, 3.0], [0.2, -0.6, -1.4, -2.2],
                              [1.5, 2.0, 2.5, 3.0]])
            ]))
        self.assertTrue(
            tensor_lists_equal(sentence_states, [
                torch.tensor([[2., 4., 6., 8.], [-2., -4., -6., -8.]]),
                torch.tensor([[1., 2., 3., 4.], [-1., -2., -3., -4.],
                              [1., 2., 3., 4.]])
            ]))
        self.assertTrue(
            tensor_lists_equal(word_states, [
                torch.tensor([[1., 1.], [1., 1.]]),
                torch.tensor([[-1., -1.], [-1., -1.], [-1., -1.]])
            ]))
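
Here too the expected outputs follow from plain arithmetic; the first outer step of each paragraph can be verified by hand:

## paragraph 1, outer step 1: each word of sentence 1 is shifted by the
## word state mean 1.0, so the last word output is 0.5 + 1.0 = 1.5;
## the image [2, 2, 2] has mean 2.0
out_p1 = [1.5 * s + 2.0 for s in [-2, -4, -6, -8]]  ## [-1, -4, -7, -10]

## paragraph 2, outer step 1: word state mean is -1.0, so the last word
## output is 0.5 - 1.0 = -0.5; the image [1, 1, 1] has mean 1.0
out_p2 = [-0.5 * s + 1.0 for s in [-1, -2, -3, -4]]  ## [1.5, 2.0, 2.5, 3.0]
## these match the first rows of the two expected output tensors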