def forward(self, sentence):
    def embedding(word):
        return [self.embedding_table(word.squeeze(-1))], []

    embeddings, = rc.recurrent_group(seq_inputs=[sentence], step_func=embedding)
    bow = sequence_pooling(embeddings, pooling_type="sum")
    if self.one_more_hidden:
        return self.hidden(bow)
    return bow
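## For reference, a minimal sketch of what a "sum" sequence pooling could look
## like, assuming recurrent_group returns one (seq_len, dim) tensor per
## sequence in the batch and that torch is imported at module level as
## elsewhere in this file. This helper is an illustrative assumption, not the
## actual sequence_pooling implementation.
def sum_pooling_sketch(seqs):
    ## Sum each sequence over its time dimension, then stack the per-sequence
    ## vectors into a (batch, dim) bag-of-words tensor.
    return torch.stack([s.sum(dim=0) for s in seqs])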
def test_multi_sequential_inputs(self):
    ## We first make a batch of paragraphs of sentences.
    ## Suppose that each word has a 1-d embedding.
    sentences = [
        ## paragraph 1
        [
            [[0.3], [0.4], [0.5]],  ## sentence 1
            [[0.1], [0.2]]          ## sentence 2
        ],
        ## paragraph 2
        [
            [[0.3], [0.4], [0.5]],        ## sentence 3
            [[0.2], [0.2]],               ## sentence 4
            [[1.0], [0.2], [0.4], [0.5]]  ## sentence 5
        ],
    ]
    sentence_tensors = rc.make_hierarchy_of_tensors(
        sentences, "float32", "cpu", [1])
    ## We then make a batch of paragraphs of images.
    images = [
        ## paragraph 1
        [
            [0.1, 0.3],  ## image 1
            [0, -1]      ## image 2
        ],
        ## paragraph 2
        [
            [0, 1],  ## image 3
            [1, 0],  ## image 4
            [1, 1]   ## image 5
        ],
    ]
    image_tensors = rc.make_hierarchy_of_tensors(images, "float32", "cpu", [2])
    states = [
        [-2, -4, -6, -8],  ## paragraph 1
        [-1, -2, -3, -4],  ## paragraph 2
    ]
    state_tensors = rc.make_hierarchy_of_tensors(states, "float32", "cpu", [4])

    def step_func(sentence, image, state):
        """
        We compute the first output by taking the outer product of the
        average word embedding and the image embedding. We compute the
        second output by adding the average word embedding to the mean
        image value. The state mean is added to both outputs, and the
        state is updated by multiplying it with -1.
        """
        assert isinstance(sentence, list)
        assert isinstance(image, torch.Tensor)
        assert isinstance(state, torch.Tensor)
        sentence = torch.stack([sen.mean(0) for sen in sentence])
        assert sentence.size()[0] == image.size()[0]
        mean_state = state.mean(-1).unsqueeze(-1)
        out1 = torch.bmm(sentence.unsqueeze(2), image.unsqueeze(1)).view(
            sentence.size()[0], -1) + mean_state
        out2 = sentence + image.mean(-1).unsqueeze(-1) + mean_state
        return [out1, out2], [state * -1]

    outs = rc.recurrent_group([sentence_tensors, image_tensors], [],
                              [state_tensors], step_func, out_states=True)
    self.assertTrue(
        tensor_lists_equal(
            outs,
            [
                [
                    torch.tensor([[-4.9600, -4.8800], [5.0000, 4.8500]]),
                    torch.tensor([[-2.5000, -2.1000], [2.7000, 2.5000],
                                  [-1.9750, -1.9750]])
                ],  ## out1
                [
                    torch.tensor([[-4.4000], [4.6500]]),
                    torch.tensor([[-1.6000], [3.2000], [-0.9750]])
                ],  ## out2
                [
                    torch.tensor([[2., 4., 6., 8.], [-2., -4., -6., -8.]]),
                    torch.tensor([[1., 2., 3., 4.], [-1., -2., -3., -4.],
                                  [1., 2., 3., 4.]])
                ]  ## state
            ]))
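## Illustrative sanity check (added for clarity, not part of the original
## suite): the first expected out2 entry re-derived by hand from the
## step_func formula above.
def test_out2_first_entry_by_hand(self):
    word_mean = torch.tensor([0.3, 0.4, 0.5]).mean()        ## sentence 1 -> 0.4
    image_mean = torch.tensor([0.1, 0.3]).mean()            ## image 1 -> 0.2
    state_mean = torch.tensor([-2., -4., -6., -8.]).mean()  ## paragraph 1 -> -5.0
    ## out2 = average word embedding + mean image value + state mean = -4.4,
    ## which matches the expected out2[0][0] above.
    self.assertAlmostEqual(
        (word_mean + image_mean + state_mean).item(), -4.4, places=5)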
def test_hierarchical_sequences(self):
    sentences = [
        ## paragraph 1
        [
            [[0.3], [0.4], [0.5]],  ## sentence 1
            [[0.1], [0.2]]          ## sentence 2
        ],
        ## paragraph 2
        [
            [[0.3], [0.4], [0.5]],        ## sentence 3
            [[0.2], [0.2]],               ## sentence 4
            [[1.0], [0.2], [0.4], [0.5]]  ## sentence 5
        ],
    ]
    imgs = [
        [2.0, 2.0, 2.0],  ## image 1
        [1.0, 1.0, 1.0]   ## image 2
    ]
    sentence_tensors = rc.make_hierarchy_of_tensors(
        sentences, "float32", "cpu", [1])
    img_tensors = rc.make_hierarchy_of_tensors(imgs, "float32", "cpu", [3])
    sentence_states = [
        [-2, -4, -6, -8],  ## paragraph 1
        [-1, -2, -3, -4],  ## paragraph 2
    ]
    sentence_state_tensors = rc.make_hierarchy_of_tensors(
        sentence_states, "float32", "cpu", [4])
    word_states = [
        [1.0, 1.0],    ## sentence 1
        [-1.0, -1.0],  ## sentence 3
    ]
    word_state_tensors = rc.make_hierarchy_of_tensors(
        word_states, "float32", "cpu", [2])

    ## This hierarchical step function does the following:
    ## 1. For each word in each sentence, we add the word state mean to the
    ##    word embedding; the word state stays the same throughout.
    ## 2. We take the last output of the words and the last word state.
    ## 3. At the higher level, we multiply the last word output by the
    ##    sentence state and add the mean of the static image input. We
    ##    update the sentence state by multiplying it with -1.
    def step_func(sentence, img, sentence_state, word_state):
        assert isinstance(sentence, list)

        def inner_step_func(w, ws):
            ### w is the current word embedding
            ### ws is the current word state
            assert isinstance(w, torch.Tensor)
            assert isinstance(ws, torch.Tensor)
            ## return the output and the (unchanged) state
            return [w + ws.mean(-1).unsqueeze(-1)], [ws]

        outputs, word_states = rc.recurrent_group(
            seq_inputs=[sentence],
            insts=[],
            init_states=[word_state],
            step_func=inner_step_func,
            out_states=True)
        last_outputs = torch.stack([o[-1] for o in outputs])
        last_word_states = torch.stack([s[-1] for s in word_states])
        ## we compute the output by multiplying the last word output
        ## with the sentence state
        out = last_outputs * sentence_state + img.mean(-1).unsqueeze(-1)
        return [out], [sentence_state * -1, last_word_states]

    outs, sentence_states, word_states = rc.recurrent_group(
        seq_inputs=[sentence_tensors],
        insts=[img_tensors],
        init_states=[sentence_state_tensors, word_state_tensors],
        step_func=step_func,
        out_states=True)
    self.assertTrue(
        tensor_lists_equal(outs, [
            torch.tensor([[-1.0, -4.0, -7.0, -10.0],
                          [4.4, 6.8, 9.2, 11.6]]),
            torch.tensor([[1.5, 2.0, 2.5, 3.0],
                          [0.2, -0.6, -1.4, -2.2],
                          [1.5, 2.0, 2.5, 3.0]])
        ]))
    self.assertTrue(
        tensor_lists_equal(sentence_states, [
            torch.tensor([[2., 4., 6., 8.], [-2., -4., -6., -8.]]),
            torch.tensor([[1., 2., 3., 4.], [-1., -2., -3., -4.],
                          [1., 2., 3., 4.]])
        ]))
    self.assertTrue(
        tensor_lists_equal(word_states, [
            torch.tensor([[1., 1.], [1., 1.]]),
            torch.tensor([[-1., -1.], [-1., -1.], [-1., -1.]])
        ]))
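## Illustrative sanity check (added for clarity, not part of the original
## suite): the first expected output row of the hierarchical test re-derived
## by hand from the two-level step functions above.
def test_hierarchical_first_row_by_hand(self):
    ## Step 1, paragraph 1: the last word of sentence 1 is 0.5, and the word
    ## state [1, 1] contributes its mean (1.0), so the last word output is 1.5.
    last_word_out = 0.5 + torch.tensor([1., 1.]).mean()
    sentence_state = torch.tensor([-2., -4., -6., -8.])
    img_mean = torch.tensor([2., 2., 2.]).mean()  ## image 1 -> 2.0
    ## out = last word output * sentence state + image mean
    out = last_word_out * sentence_state + img_mean
    self.assertTrue(
        torch.allclose(out, torch.tensor([-1., -4., -7., -10.])))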