def forward(self, input_embeds_list):
    """
    Args:
        input_embeds_list (list[SequenceBatchElement]): each element's
            values has shape (batch_size, input_dim)

    Returns:
        hidden_states_list (list[SequenceBatchElement]): each element's
            values has shape (batch_size, hidden_dim)
    """
    batch_size = input_embeds_list[0].values.size(0)
    h = tile_state(self.h0, batch_size)  # (batch_size, hidden_dim)
    c = tile_state(self.c0, batch_size)  # (batch_size, hidden_dim)

    hidden_states_list = []
    for x in input_embeds_list:
        # x.values has shape (batch_size, input_dim)
        # x.mask has shape (batch_size, 1)
        h_new, c_new = self.rnn_cell(x.values, (h, c))
        # only advance the state at unmasked (non-padding) positions
        h = gated_update(h, h_new, x.mask)
        c = gated_update(c, c_new, x.mask)
        hidden_states_list.append(
            SequenceBatchElement(self.dropout(h), x.mask))
    return hidden_states_list
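# A minimal sketch of the `gated_update` helper assumed by the loop above
# (hypothetical; the original repo's version may differ). It keeps the old
# state wherever mask == 0 (padding) and takes the new state wherever
# mask == 1:
def gated_update(old, new, mask):
    # old, new: (batch_size, hidden_dim); mask: (batch_size, 1) of 0/1 floats
    # the mask broadcasts across hidden_dim, selecting per example
    return mask * new + (1 - mask) * old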
def initialize(self, batch_size):
    h = tile_state(self.h0, batch_size)
    c = tile_state(self.c0, batch_size)

    def init_attn(attention):
        # no initial attention weights; the context starts as a zero vector
        return AttentionOutput(
            None, GPUVariable(torch.zeros(batch_size, attention.memory_dim)))

    return AttentionRNNState(
        [h] * self.num_layers, [c] * self.num_layers,
        init_attn(self.source_attention),
        init_attn(self.insert_attention),
        init_attn(self.delete_attention))
def initialize(self, batch_size):
    h = tile_state(self.h0, batch_size)
    c = tile_state(self.c0, batch_size)

    def init_attn(attention):
        # no initial weights or logits; the context starts as a zero vector
        return AttentionOutput(
            weights=None,
            context=GPUVariable(torch.zeros(batch_size, attention.memory_dim)),
            logits=None)

    return AttentionRNNState(
        [h] * self.num_layers, [c] * self.num_layers,
        [init_attn(attn) for attn in self.input_attentions])
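# For reference, a plausible definition of AttentionOutput matching the
# keyword arguments above (an assumption; the original class may carry
# extra behavior beyond a plain record):
from collections import namedtuple

AttentionOutput = namedtuple('AttentionOutput', ['weights', 'context', 'logits'])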
def initialize(self, batch_size):
    h = tile_state(self.h0, batch_size)
    c = tile_state(self.c0, batch_size)
    return MultilayeredRNNState([h] * self.num_layers, [c] * self.num_layers)
def initialize(self, batch_size):
    h = tile_state(self.h0, batch_size)
    c = tile_state(self.c0, batch_size)
    return SimpleRNNState(h, c)
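# SimpleRNNState, MultilayeredRNNState, and AttentionRNNState appear to be
# plain state containers; a plausible definition for the simplest one (an
# assumption, the originals may differ):
from collections import namedtuple

SimpleRNNState = namedtuple('SimpleRNNState', ['h', 'c'])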
def _encoder_output(self, batch_size):
    return tile_state(self.agenda, batch_size)
def test_tile_state():
    h = GPUVariable(torch.FloatTensor([1, 2, 3]))
    h_tiled = tile_state(h, 3)
    assert_tensor_equal(h_tiled, [[1, 2, 3],
                                  [1, 2, 3],
                                  [1, 2, 3]])
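# A sketch of `tile_state` consistent with the test above (hypothetical; the
# original helper may be implemented differently): repeat a (hidden_dim,)
# initial-state parameter across a new batch dimension.
def tile_state(state, batch_size):
    # (hidden_dim,) -> (1, hidden_dim) -> (batch_size, hidden_dim)
    return state.unsqueeze(0).repeat(batch_size, 1)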