def forward(self, input_embeds_list):
    """
    Args:
        input_embeds_list (list[SequenceBatchElement]): where each element
            is of shape (batch_size, input_dim)

    Returns:
        hidden_states_list (list[SequenceBatchElement]): where each element
            is (batch_size, hidden_dim)
    """
    batch_size = input_embeds_list[0].values.size()[0]
    h = tile_state(self.h0, batch_size)  # (batch_size, hidden_dim)
    c = tile_state(self.c0, batch_size)  # (batch_size, hidden_dim)

    hidden_states_list = []
    for t, x in enumerate(input_embeds_list):
        # x.values has shape (batch_size, input_dim)
        # x.mask has shape (batch_size, 1)
        h_new, c_new = self.rnn_cell(x.values, (h, c))
        h = gated_update(h, h_new, x.mask)
        c = gated_update(c, c_new, x.mask)
        hidden_states_list.append(
            SequenceBatchElement(self.dropout(h), x.mask))

    return hidden_states_list
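# The two helpers used above, tile_state and gated_update, are assumed from
# their call sites; the following is a minimal sketch of plausible
# implementations, not necessarily the real ones. tile_state expands a learned
# (1, hidden_dim) initial state to the batch, and gated_update keeps the old
# state wherever mask == 0 (i.e. at padded positions).

def tile_state(state, batch_size):
    # (1, hidden_dim) -> (batch_size, hidden_dim)
    return state.expand(batch_size, state.size(1))


def gated_update(old, new, mask):
    # mask: (batch_size, 1), expanded over the hidden dimension
    mask = mask.expand_as(new)
    return mask * new + (1 - mask) * old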
def clear_cache(self):
    # Keep the empty tuple cached, for SequenceBatch
    self._cache.clear()
    self._cache.cache(
        [tuple()],
        [(GPUVariable(torch.zeros(self._embed_dim)),
          SequenceBatchElement(
              GPUVariable(torch.zeros(1, self._embed_dim)),
              GPUVariable(torch.zeros(1))))])
def combined_states(self):
    """Concatenates forward and backward hidden states: [forward; backward].

    Returns:
        combined_states (list[SequenceBatchElement]): ordered left to right
    """
    combined_states = [
        SequenceBatchElement(torch.cat([f.values, b.values], 1), f.mask)
        for f, b in izip(self.forward_states, self.backward_states)]
    return combined_states
def forward(self, utterance):
    """Embeds a batch of utterances.

    Args:
        utterance (list[list[unicode]]): list[unicode] is a list of tokens
            forming a sentence. list[list[unicode]] is a batch of sentences.

    Returns:
        Variable[FloatTensor]: batch x lstm_dim (concatenated first and last
            hidden states)
        list[SequenceBatchElement]: list of length batch, where each
            element's values is seq_len x embed_dim and mask is seq_len,
            representing the hidden states of each token.
    """
    # Make keys hashable
    utterance = [tuple(utt) for utt in utterance]

    uncached_utterances = self._cache.uncached_keys(utterance)

    # Cache the uncached utterances
    if len(uncached_utterances) > 0:
        token_indices = SequenceBatch.from_sequences(
            uncached_utterances, self._token_embedder.vocab)
        # batch x seq_len x token_embed_dim
        token_embeds = self._token_embedder.embed_seq_batch(token_indices)
        bi_hidden_states = self._bilstm(token_embeds.split())
        final_states = torch.cat(bi_hidden_states.final_states, 1)

        # Store the combined states in batch-major order for caching.
        combined_states = bi_hidden_states.combined_states
        # batch x seq_len x embed_dim
        combined_values = torch.stack(
            [state.values for state in combined_states], 1)
        # batch x seq_len
        combined_masks = torch.stack(
            [state.mask for state in combined_states], 1)
        assert len(combined_values) == len(combined_masks)

        combined_states_by_batch = [
            SequenceBatchElement(value, mask)
            for value, mask in zip(combined_values, combined_masks)]
        assert len(final_states) == len(combined_states_by_batch)

        self._cache.cache(
            list(uncached_utterances),
            list(zip(final_states, combined_states_by_batch)))

    final_states, combined_states = zip(*self._cache.get(utterance))
    return torch.stack(final_states, 0), combined_states
def test_cat(self):
    x1 = SequenceBatchElement(
        GPUVariable(torch.FloatTensor([
            [[1, 2], [3, 4]],
            [[8, 2], [9, 0]]])),
        GPUVariable(torch.FloatTensor([
            [1],
            [1]])))
    x2 = SequenceBatchElement(
        GPUVariable(torch.FloatTensor([
            [[-1, 20], [3, 40]],
            [[-8, 2], [9, 10]]])),
        GPUVariable(torch.FloatTensor([
            [1],
            [0]])))
    x3 = SequenceBatchElement(
        GPUVariable(torch.FloatTensor([
            [[-1, 20], [3, 40]],
            [[-8, 2], [9, 10]]])),
        GPUVariable(torch.FloatTensor([
            [0],
            [0]])))

    result = SequenceBatch.cat([x1, x2, x3])

    assert_tensor_equal(result.values, [
        [[[1, 2], [3, 4]], [[-1, 20], [3, 40]], [[-1, 20], [3, 40]]],
        [[[8, 2], [9, 0]], [[-8, 2], [9, 10]], [[-8, 2], [9, 10]]],
    ])
    assert_tensor_equal(result.mask, [
        [1, 1, 0],
        [1, 0, 0],
    ])
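# A minimal sketch of the behaviour test_cat expects from SequenceBatch.cat,
# assuming SequenceBatch is a simple (values, mask) pair; this is an
# illustrative stand-in, not the library's actual implementation.

def cat(elements):
    # elements: list[SequenceBatchElement], each with values (batch, ...) and
    # mask (batch, 1); stack values along a new time dimension, concat masks
    values = torch.stack([e.values for e in elements], 1)  # (batch, seq_len, ...)
    mask = torch.cat([e.mask for e in elements], 1)        # (batch, seq_len)
    return SequenceBatch(values, mask)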
def forward(self, input_embeds_list):
    """
    Args:
        input_embeds_list (list[SequenceBatchElement]): where each element
            is of shape (batch_size, input_dim)

    Returns:
        BidirectionalEncoderOutput
    """
    for i, layer in enumerate(self.layers):
        if i == 0:
            prev_hidden_states = input_embeds_list
        else:
            prev_hidden_states = [
                SequenceBatchElement(
                    torch.cat([f.values, b.values], 1), f.mask)
                for f, b in izip(forward_states, backward_states)]

        new_forward_states, new_backward_states = layer(prev_hidden_states)

        if i == 0:
            # no skip connections here, because dimensions don't match
            forward_states, backward_states = new_forward_states, new_backward_states
        else:
            # add residuals to previous hidden states
            add_residuals = lambda a_list, b_list: [
                SequenceBatchElement(a.values + b.values, a.mask)
                for a, b in izip(a_list, b_list)]
            forward_states = add_residuals(forward_states, new_forward_states)
            backward_states = add_residuals(backward_states, new_backward_states)

    return BidirectionalEncoderOutput(forward_states, backward_states)
def forward(self, encoder_output, train_decoder_input):
    """
    Args:
        encoder_output (EncoderOutput)
        train_decoder_input (TrainDecoderInput)

    Returns:
        vocab_probs (list[Variable]): vocab distributions at each time step
        rnn_states (list[RNNState])
        losses (SequenceBatch): per-token losses, of shape
            (batch_size, target_seq_length)
    """
    batch_size, _ = train_decoder_input.input_words.mask.size()
    rnn_state = self.decoder_cell.initialize(batch_size)

    input_word_embeds = encoder_output.token_embedder.embed_seq_batch(
        train_decoder_input.input_words)
    input_embed_list = input_word_embeds.split()
    target_word_list = train_decoder_input.target_words.split()

    loss_list = []
    rnn_states = []
    vocab_probs = []
    for t, (x, target_word) in enumerate(
            izip(input_embed_list, target_word_list)):
        # x is a SequenceBatchElement with values of shape (batch_size, word_dim)
        # target_word is a SequenceBatchElement with values of shape (batch_size,)

        # update rnn state
        rnn_input = self.rnn_context_combiner(encoder_output, x.values)
        decoder_cell_output = self.decoder_cell(rnn_state, rnn_input, x.mask)
        rnn_state = decoder_cell_output.rnn_state
        rnn_states.append(rnn_state)
        vocab_probs.append(decoder_cell_output.vocab_probs)

        # compute loss
        loss = decoder_cell_output.loss(target_word.values)  # (batch_size,)
        loss_list.append(SequenceBatchElement(loss, x.mask))

    losses = SequenceBatch.cat(loss_list)  # (batch_size, target_seq_length)
    return vocab_probs, rnn_states, losses
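# A hedged sketch of how a caller might reduce the per-token losses returned
# above to a single scalar: zero out padded positions with the mask and
# average over the real tokens. The helper name is illustrative only.

def masked_mean_loss(losses):
    # losses.values: (batch_size, target_seq_length); losses.mask: same shape
    total = (losses.values * losses.mask).sum()
    num_tokens = losses.mask.sum()
    return total / num_tokens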