def test_embed(self):
    sequences = [
        [],
        [1, 2, 3],
        [3, 3],
        [2],
    ]

    vocab = SimpleVocab([0, 1, 2, 3, 4])
    indices = SequenceBatch.from_sequences(sequences, vocab)

    embeds = GPUVariable(torch.FloatTensor([
        [0, 0],
        [2, 2],    # 1
        [3, 4],    # 2
        [-10, 1],  # 3
        [11, -1],  # 4
    ]))

    embedded = SequenceBatch.embed(indices, embeds)

    correct = np.array([
        [[0, 0], [0, 0], [0, 0]],
        [[2, 2], [3, 4], [-10, 1]],
        [[-10, 1], [-10, 1], [0, 0]],
        [[3, 4], [0, 0], [0, 0]],
    ], dtype=np.float32)

    assert_tensor_equal(embedded.values, correct)
def test_reduce_max(self, some_seq_batch):
    with pytest.raises(ValueError):
        # should complain about empty sequence
        SequenceBatch.reduce_max(some_seq_batch)

    values = GPUVariable(torch.FloatTensor([
        # actual max is in later elements, but should be suppressed by mask
        [[1, 2], [4, 5], [4, 4]],
        # note that all elements in 2nd dim are negative
        [[0, -4], [43, -5], [-1, -20]],
    ]))
    mask = GPUVariable(torch.FloatTensor([
        [1, 0, 0],
        [1, 1, 0],
    ]))
    seq_batch = SequenceBatch(values, mask)
    result = SequenceBatch.reduce_max(seq_batch)

    assert_tensor_equal(result, [
        [1, 2],
        [43, -4],
    ])
def forward(self, insert_embeds, insert_embeds_exact, delete_embeds, delete_embeds_exact,
            draw_samples=False, draw_p=False):
    """Create agenda vector.

    Args:
        insert_embeds (SequenceBatch): of shape (batch_size, max_edits, word_dim)
        insert_embeds_exact (SequenceBatch): of shape (batch_size, max_edits, word_dim)
        delete_embeds (SequenceBatch): of shape (batch_size, max_edits, word_dim)
        delete_embeds_exact (SequenceBatch): of shape (batch_size, max_edits, word_dim)
        draw_samples (bool): whether to add noise for the variational approximation.
            Disable at test time.
        draw_p (bool): if sampling, replace the edit vector entirely with a draw
            from the noise distribution instead of perturbing the encoded edit.

    Returns:
        edit_embed (Variable): of shape (batch_size, edit_vec_dim)
    """
    insert_embed = SequenceBatch.reduce_sum(insert_embeds)  # (batch_size, word_dim)
    insert_embed += SequenceBatch.reduce_sum(insert_embeds_exact)  # (batch_size, word_dim)
    delete_embed = SequenceBatch.reduce_sum(delete_embeds)  # (batch_size, word_dim)
    delete_embed += SequenceBatch.reduce_sum(delete_embeds_exact)  # (batch_size, word_dim)

    insert_set = self.linear_prenoise(insert_embed)
    delete_set = self.linear_prenoise(delete_embed)
    combined_map = torch.cat([insert_set, delete_set], 1)

    if draw_samples:
        if draw_p:
            batch_size, edit_dim = combined_map.size()
            combined_map = self.draw_p_noise(batch_size, edit_dim)
        else:
            combined_map = self.sample_vMF(combined_map, self.noise_scaler)

    edit_embed = combined_map
    return edit_embed
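# Sketch of how the sampling flags above might be used (illustrative only; actual
# call sites in the repo may differ):
#   training (variational noise on):  edit_embed = edit_encoder(ins, ins_ex, dels, dels_ex, draw_samples=True)
#   test time (deterministic):        edit_embed = edit_encoder(ins, ins_ex, dels, dels_ex, draw_samples=False)
#   sample from the noise prior only: edit_embed = edit_encoder(ins, ins_ex, dels, dels_ex, draw_samples=True, draw_p=True)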
def base_plus_copy_indices(words, dynamic_vocabs, base_vocab, volatile=False):
    """Compute base + copy indices.

    Args:
        words (list[list[unicode]])
        dynamic_vocabs (list[HardCopyDynamicVocab])
        base_vocab (HardCopyVocab)
        volatile (bool)

    Returns:
        MultiVocabIndices
    """
    unk = base_vocab.UNK
    copy_seqs = []
    for seq, dyna_vocab in izip(words, dynamic_vocabs):
        word_to_copy = dyna_vocab.word_to_copy_token
        normal_copy_seq = []
        for w in seq:
            normal_copy_seq.append(word_to_copy.get(w, unk))
        copy_seqs.append(normal_copy_seq)

    # each SequenceBatch.values has shape (batch_size, seq_length)
    base_indices = SequenceBatch.from_sequences(words, base_vocab, volatile=volatile)
    copy_indices = SequenceBatch.from_sequences(copy_seqs, base_vocab, volatile=volatile)

    assert_tensor_equal(base_indices.mask, copy_indices.mask)

    # has shape (batch_size, seq_length, 2)
    concat_values = torch.stack([base_indices.values, copy_indices.values], 2)

    return MultiVocabIndices(concat_values, base_indices.mask)
def test_reduce_mean(self, some_seq_batch):
    result = SequenceBatch.reduce_mean(some_seq_batch, allow_empty=True)
    assert_tensor_equal(result, [
        [2.5, 3.5],
        [0, 4],
        [0, 0],
    ])

    with pytest.raises(ValueError):
        SequenceBatch.reduce_mean(some_seq_batch, allow_empty=False)
def test_split(self):
    input_embeds = GPUVariable(torch.LongTensor([
        # batch item 1
        [[1, 2], [2, 3], [5, 6]],
        # batch item 2
        [[4, 8], [3, 5], [0, 0]],
    ]))

    input_mask = GPUVariable(torch.FloatTensor([
        [1, 1, 1],
        [1, 1, 0],
    ]))

    sb = SequenceBatch(input_embeds, input_mask)
    elements = sb.split()
    input_list = [e.values for e in elements]
    mask_list = [e.mask for e in elements]

    assert len(input_list) == 3
    assert_tensor_equal(input_list[0], [[1, 2], [4, 8]])
    assert_tensor_equal(input_list[1], [[2, 3], [3, 5]])
    assert_tensor_equal(input_list[2], [[5, 6], [0, 0]])

    assert len(mask_list) == 3
    assert_tensor_equal(mask_list[0], [[1], [1]])
    assert_tensor_equal(mask_list[1], [[1], [1]])
    assert_tensor_equal(mask_list[2], [[1], [0]])
def preprocess(self, source_words, insert_words, insert_exact_words,
               delete_words, delete_exact_words, edit_embed):
    """Preprocess.

    Args:
        source_words (list[list[unicode]]): a batch of source sequences
        insert_words (list[list[unicode]]): a batch of insert words
        insert_exact_words (list[list[unicode]]): a batch of insert words, used without noise
        delete_words (list[list[unicode]]): a batch of delete words
        delete_exact_words (list[list[unicode]]): a batch of delete words, used without noise
        edit_embed (np.ndarray | None): of shape (batch_size, edit_dim), or None

    Returns:
        EncoderInput
    """
    return EncoderInput(
        SequenceBatch.from_sequences(source_words, self.word_vocab),
        SequenceBatch.from_sequences(insert_words, self.word_vocab, min_seq_length=1),
        SequenceBatch.from_sequences(insert_exact_words, self.word_vocab, min_seq_length=1),
        SequenceBatch.from_sequences(delete_words, self.word_vocab, min_seq_length=1),
        SequenceBatch.from_sequences(delete_exact_words, self.word_vocab, min_seq_length=1),
        edit_embed)
def forward(self, utterances):
    """Embeds a batch of utterances.

    Args:
        utterances (list[list[unicode]]): list[unicode] is a list of tokens
            forming a sentence. list[list[unicode]] is a batch of sentences.

    Returns:
        Variable[FloatTensor]: batch x lstm_dim (concatenated first and last
            hidden states)
    """
    # Cut to max_words, append EOS, and look up indices
    utterances = [utterance[:self._max_words] + [EOS] for utterance in utterances]
    token_indices = SequenceBatch.from_sequences(
        utterances, self._token_embedder.vocab)

    # batch x seq_len x token_embed_dim
    token_embeds = self._token_embedder.embed_seq_batch(token_indices)

    bi_hidden_states = self._bilstm(token_embeds.split())
    final_states = torch.cat(bi_hidden_states.final_states, 1)

    hidden_states = SequenceBatch.cat(bi_hidden_states.combined_states)
    return self._attention(hidden_states, final_states).context
def __init__(self, target_words, word_vocab, keep_rate):
    input_words = [[word_vocab.START] + tokens for tokens in target_words]
    target_words_shifted = [tokens + [word_vocab.STOP] for tokens in target_words]

    input_words = SequenceBatch.from_sequences(input_words, word_vocab)
    self.input_words = self._drop_seq_batch(input_words, word_vocab, keep_rate)
    self.target_words = SequenceBatch.from_sequences(target_words_shifted, word_vocab)
def forward(self, old_embeds, neighbors, rels):
    batch_size = len(old_embeds)

    # Gather each node's neighbor embeddings: (batch_size, max_neighbors, embed_dim)
    neighbor_embeds = torch.index_select(old_embeds, 0, neighbors.values.view(-1))
    neighbor_embeds = neighbor_embeds.view(batch_size, neighbors.values.shape[1], -1)
    neighbor_embeds = SequenceBatch(neighbor_embeds, neighbors.mask)

    # Max-pool over neighbors, then combine with each node's own embedding
    pooled = SequenceBatch.reduce_max(neighbor_embeds)
    combined = torch.cat((old_embeds, pooled), dim=1)
    return F.relu(self._proj(self._dropout(combined)))
def _query_embeds(self, states, query_entries):
    """Given a batch of states, embed the keys and values of each state's query.

    Args:
        states (list[MiniWoBState])

    Returns:
        entry_embeds (SequenceBatch): batch x num_keys x (2 * embed_dim),
            the key and value embeddings concatenated
    """
    fields_batch = [state.fields for state in states]

    # list[list[list[unicode]]] (batch x num_keys x key length)
    values_batch = [[word_tokenize(value) for value in fields.values]
                    for fields in fields_batch]
    keys_batch = [[word_tokenize(key) for key in fields.keys]
                  for fields in fields_batch]

    # Pad
    batch_size = len(fields_batch)
    max_num_fields = max(len(values) for values in values_batch)
    max_num_fields = max(max_num_fields, 1)  # Ensure non-empty
    mask = torch.ones(batch_size, max_num_fields)
    assert len(keys_batch) == len(values_batch) == len(mask)
    for keys, values, submask in zip(keys_batch, values_batch, mask):
        assert len(keys) == len(values)
        if len(keys) < max_num_fields:
            submask[len(keys):] = 0.
            keys.extend([[UtteranceVocab.PAD]
                         for _ in xrange(max_num_fields - len(keys))])
            values.extend([[UtteranceVocab.PAD]
                           for _ in xrange(max_num_fields - len(values))])

    # Flatten to list[list[unicode]] (batch * num_keys) x key length
    keys_batch = flatten(keys_batch)
    values_batch = flatten(values_batch)

    # Embed and mask: (batch * num_keys) x embed_dim
    key_embeds, _ = self._utterance_embedder(keys_batch)
    key_embeds = key_embeds.view(
        batch_size, max_num_fields, self._utterance_embedder.embed_dim)
    value_embeds, _ = self._utterance_embedder(values_batch)
    value_embeds = value_embeds.view(
        batch_size, max_num_fields, self._utterance_embedder.embed_dim)
    key_embeds = SequenceBatch(key_embeds, GPUVariable(mask))
    value_embeds = SequenceBatch(value_embeds, GPUVariable(mask))

    entry_embed_values = torch.cat(
        [key_embeds.values, value_embeds.values], 2)
    entry_embeds = SequenceBatch(entry_embed_values, key_embeds.mask)

    return entry_embeds
def forward(self, old_embeds, neighbors, rels):
    batch_size = len(old_embeds)

    projected = F.relu(self._proj(self._dropout(old_embeds)))
    neighbor_embeds = torch.index_select(projected, 0, neighbors.values.view(-1))
    neighbor_embeds = neighbor_embeds.view(batch_size, neighbors.values.shape[1], -1)

    # Prepend each node's own embedding to its neighbors, with a matching mask entry
    combined = torch.cat((projected.unsqueeze(1), neighbor_embeds), dim=1)
    mask = torch.cat((V(torch.ones(batch_size, 1)), neighbors.mask), dim=1)
    combined = SequenceBatch(combined, mask)
    return SequenceBatch.reduce_max(combined)
def __init__(self, target_words, word_vocab):
    """Create TrainDecoderInput.

    Args:
        target_words (list[list[unicode]])
        word_vocab (WordVocab)
    """
    input_words = [[word_vocab.START] + tokens for tokens in target_words]  # prepend <start> token
    target_words_shifted = [tokens + [word_vocab.STOP] for tokens in target_words]  # append <stop> token

    self.input_words = SequenceBatch.from_sequences(input_words, word_vocab)
    self.target_words = SequenceBatch.from_sequences(target_words_shifted, word_vocab)
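# Illustration (hypothetical tokens): for target_words = [[u'the', u'cat']],
# the decoder input is   [<start>, the, cat]
# and the training target is  [the, cat, <stop>],
# i.e. the same sequence shifted by one position.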
def embed_seq_batch(self, seq_batch):
    """Embed elements of a SequenceBatch.

    Args:
        seq_batch (SequenceBatch)

    Returns:
        SequenceBatch
    """
    if torch.cuda.is_available():
        return SequenceBatch(self._embedding(seq_batch.values.cuda()), seq_batch.mask)
    else:
        return SequenceBatch(self._embedding(seq_batch.values), seq_batch.mask)
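# Note: the .cuda() branch above assumes the embedding module's weights already live
# on the GPU; it only moves the index tensor so the lookup runs on the same device.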
def forward(self, dom_elements, alignment_fields):
    """Computes the alignments. An element aligns iff elem.text is in the
    utterance and elem.text != "".

    Args:
        dom_elements (list[list[DOMElement]]): batch of sets of DOM elements
            (padded to be unragged)
        alignment_fields (list[Fields]): batch of fields. Alignments are
            computed against the values of the fields.

    Returns:
        Variable[FloatTensor]: batch x num_elems x embed_dim
            The aligned embeddings per DOM element
    """
    batch_size = len(dom_elements)
    assert batch_size > 0
    num_dom_elems = len(dom_elements[0])
    assert num_dom_elems > 0

    # alignment mask: batch_size x num_dom_elems x num_buckets
    alignments = np.zeros(
        (batch_size, num_dom_elems, self._num_buckets)).astype(np.float32)

    # Calculate the alignment matrix between elems and fields
    for batch_idx in xrange(len(dom_elements)):
        for dom_idx, dom in enumerate(dom_elements[batch_idx]):
            keys = alignment_fields[batch_idx].keys
            vals = alignment_fields[batch_idx].values
            for key, val in zip(keys, vals):
                if dom.text and dom.text in val:
                    align_idx = self._keys2index.word2index(key)
                    alignments[batch_idx, dom_idx, align_idx] = 1.

    # Flatten alignments for SequenceBatch:
    # (batch * num_dom_elems) x num_buckets
    alignments = GPUVariable(torch.from_numpy(
        alignments.reshape((batch_size * num_dom_elems, self._num_buckets))))

    # (batch * num_dom_elems) x num_buckets x embed_dim
    expanded_alignment_embeds = self._alignment_embeds.expand(
        batch_size * num_dom_elems, self._num_buckets, self.embed_dim)
    alignment_seq_batch = SequenceBatch(
        expanded_alignment_embeds, alignments, left_justify=False)

    # (batch * num_dom_elems) x alignment_embed_dim
    alignment_embeds = SequenceBatch.reduce_sum(alignment_seq_batch)
    return alignment_embeds.view(batch_size, num_dom_elems, self.embed_dim)
def seq_batch_noise(self, seq_batch, draw_noise):
    """Returns a noisy version of seq_batch, in which every vector is noisy and unit norm.

    Args:
        seq_batch (SequenceBatch): a sequence batch of elements
        draw_noise: scaling factor applied to the noisy values (e.g. zero to disable noise)

    Returns:
        SequenceBatch: noisy version of seq_batch
    """
    values = seq_batch.values
    mask = seq_batch.mask

    batch_size, max_edits, w_embed_size = values.size()
    new_values = GPUVariable(torch.from_numpy(
        np.zeros((batch_size, max_edits, w_embed_size), dtype=np.float32)))
    m_expand = mask.unsqueeze(2).expand(batch_size, max_edits, w_embed_size)

    for max_edit in range(max_edits):
        # vMF noise around the original vector where the mask is on; pure noise elsewhere
        phint = self.sample_vMF(values[:, max_edit, :], self.noise_scaler)
        prand = self.draw_p_noise(batch_size, w_embed_size)
        new_values[:, max_edit, :] = phint * m_expand[:, max_edit, :] + \
            prand * (1 - m_expand[:, max_edit, :])

    return SequenceBatch(values=new_values * draw_noise, mask=mask)
def _dom_embeds(self, states):
    """Returns the DOM embeddings for a batch of states.

    Only embeds leaf DOM elements, and pads them.

    Args:
        states (list[MiniWoBState])

    Returns:
        dom_embeds (SequenceBatch): batch x num_elems x embed_dim
        dom_elems (list[list[DOMElement]]): of shape (batch_size, num_elems).
            Padded with DOMElementPAD.
    """
    leaf_elems = [
        [elem for elem in state.dom_elements if elem.is_leaf]
        for state in states]
    dom_elems, dom_mask = self._pad_elements(leaf_elems)

    # batch x num_dom_elems x base_dom_embed_dim
    base_dom_embeds = self._base_dom_embedder(dom_elems)

    # batch x num_dom_elems x field_key_embed_dim
    dom_alignment_vectors = self._dom_to_field_alignment(
        dom_elems, self._alignment_fields(states))

    # batch x num_dom_elems x (base + fields embed dim)
    aligned_dom_embeds = torch.cat(
        [base_dom_embeds, dom_alignment_vectors], 2)

    higher_order_dom_embeds = self._higher_order_dom_embedder(
        dom_elems, aligned_dom_embeds)

    return SequenceBatch(
        higher_order_dom_embeds, dom_mask, left_justify=False), dom_elems
def generate_edits(self, encoder_input, norm):
    """Draw uniform random vectors with the given norm and use them as edit vectors."""
    source_words = encoder_input.source_words
    source_word_embeds = self.token_embedder.embed_seq_batch(source_words)
    insert_embeds = self.token_embedder.embed_seq_batch(encoder_input.insert_words)
    delete_embeds = self.token_embedder.embed_seq_batch(encoder_input.delete_words)
    insert_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.insert_exact_words)
    delete_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.delete_exact_words)

    source_encoder_output = self.source_encoder(source_word_embeds.split())
    source_embeds_list = source_encoder_output.combined_states
    source_embeds = SequenceBatch.cat(source_embeds_list)

    # the final hidden states in both the forward and backward direction, concatenated
    source_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)

    edit_encoded = self.edit_encoder(insert_embeds, delete_embeds)

    # random edit vector, rescaled to the requested norm
    rand_vec = torch.randn(edit_encoded.size())
    edit_embed = GPUVariable(
        rand_vec / torch.norm(rand_vec, 2, dim=1).expand_as(rand_vec) * norm)

    agenda = self.agenda_maker(source_embeds_final, edit_embed)
    return EncoderOutput(source_embeds, insert_embeds_exact, delete_embeds_exact, agenda)
def encoder_generate_edits(self, encoder_input):
    """Draw uniform random vectors with the given norm and use them as edit vectors."""
    source_words = encoder_input.source_words
    source_word_embeds = self.editor.encoder.token_embedder.embed_seq_batch(source_words)
    insert_embeds = self.editor.encoder.token_embedder.embed_seq_batch(encoder_input.insert_words)
    delete_embeds = self.editor.encoder.token_embedder.embed_seq_batch(encoder_input.delete_words)
    insert_embeds_exact = self.editor.encoder.token_embedder.embed_seq_batch(encoder_input.insert_exact_words)
    delete_embeds_exact = self.editor.encoder.token_embedder.embed_seq_batch(encoder_input.delete_exact_words)

    source_encoder_output = self.editor.encoder.source_encoder(source_word_embeds.split())
    source_embeds_list = source_encoder_output.combined_states
    source_embeds = SequenceBatch.cat(source_embeds_list)

    # the final hidden states in both the forward and backward direction, concatenated
    source_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)

    edit_encoded = self.editor.encoder.edit_encoder(
        insert_embeds, insert_embeds_exact, delete_embeds, delete_embeds_exact)

    # the random vector is computed as in rand_p_noise (see edit_encoder)
    torch.manual_seed(7)
    batch_size, edit_dim = edit_encoded.size()
    rand_draw = GPUVariable(torch.randn(batch_size, edit_dim))
    rand_draw = rand_draw / torch.norm(rand_draw, p=2, dim=1).expand(batch_size, edit_dim)
    rand_norms = (torch.rand(batch_size, 1) *
                  self.editor.encoder.edit_encoder.norm_max).expand(batch_size, edit_dim)
    edit_embed = rand_draw * GPUVariable(rand_norms)

    agenda = self.editor.encoder.agenda_maker(source_embeds_final, edit_embed)
    return EncoderOutput(source_embeds, insert_embeds_exact, delete_embeds_exact, agenda)
def _get_neighbor_indices(self, dom_elements, is_neighbor):
    """Compute neighbor indices.

    Args:
        dom_elements (list[list[DOMElement]]): batch of DOM elements; may include PAD elements
        is_neighbor (Callable: DOMElement x DOMElement --> bool): returns True
            if two DOM elements are neighbors of each other, otherwise False

    Returns:
        SequenceBatch: of shape (total_dom_elems, max_neighbors)
    """
    dom_element_ids = [id(e) for e in flatten(dom_elements)]
    dom_element_ids_set = set(dom_element_ids)
    vocab = SuperSimpleVocab(dom_element_ids)

    neighbors_batch = []
    for dom_batch in dom_elements:
        for dom_elem in dom_batch:
            # Optimization: a DOM PAD has no neighbors
            if isinstance(dom_elem, DOMElementPAD):
                neighbors = []
            else:
                neighbors = []
                for neighbor in dom_batch:
                    if is_neighbor(dom_elem, neighbor):
                        neighbors.append(id(neighbor))
            neighbors_batch.append(neighbors)

    neighbor_indices = SequenceBatch.from_sequences(
        neighbors_batch, vocab, min_seq_length=1)
    return neighbor_indices
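# A hypothetical is_neighbor predicate, for illustration only (the DOMElement
# attribute used here is assumed, not taken from the repo):
#
#     def same_parent(elem1, elem2):
#         return elem1 is not elem2 and elem1.parent == elem2.parent
#
#     neighbor_indices = self._get_neighbor_indices(dom_elements, is_neighbor=same_parent)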
def test_reduce_prod(self, some_seq_batch):
    result = SequenceBatch.reduce_prod(some_seq_batch)
    assert_tensor_equal(result, [
        [4, 10],
        [0, 4],
        [1, 1],
    ])
def test_log_sum_exp(self):
    values = GPUVariable(torch.FloatTensor([
        [0, 1, -2, -3],
        [-2, -5, 1, 0],
    ]))
    mask = GPUVariable(torch.FloatTensor([
        [1, 1, 1, 0],
        [1, 1, 0, 0],
    ]))
    seq_batch = SequenceBatch(values, mask, left_justify=False)
    result = SequenceBatch.log_sum_exp(seq_batch)
    correct = [1.3490122167681864, -1.9514126484262577]
    assert_tensor_equal(result, correct)
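# Worked check of the expected values above (masked entries are excluded):
#   row 1: log(e^0 + e^1 + e^-2) = log(1 + 2.71828 + 0.13534) = log(3.85362) ~ 1.34901
#   row 2: log(e^-2 + e^-5)      = log(0.13534 + 0.00674)     = log(0.14208) ~ -1.95141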
def forward(self, memory_cells, query):
    """Performs sentinel attention with a sentinel of 0. Returns the
    AttentionOutput whose weights do not include the sentinel weight.

    Args:
        memory_cells (SequenceBatch): values of shape batch x num_cells x cell_dim
        query (Variable[FloatTensor]): batch x query_dim

    Returns:
        AttentionOutput: weights do not include sentinel weights
    """
    batch_size, _, cell_dim = memory_cells.values.size()
    sentinel = self._sentinel_embed.expand(batch_size, 1, cell_dim)
    sentinel_mask = GPUVariable(torch.ones(batch_size, 1))

    cell_values_with_sentinel = torch.cat([memory_cells.values, sentinel], 1)
    cell_masks_with_sentinel = torch.cat(
        [memory_cells.mask, sentinel_mask], 1)
    cells_with_sentinel = SequenceBatch(
        cell_values_with_sentinel, cell_masks_with_sentinel, left_justify=False)

    attention_output = super(SentinelAttention, self).forward(
        cells_with_sentinel, query)
    weights_with_sentinel = attention_output.weights

    # TODO: Bring this line in after torch v0.2.0
    # weights_without_sentinel = weights_with_sentinel[batch_size, :-1]
    # attention_output = AttentionOutput(
    #     weights=weights_without_sentinel, context=attention_output.context)
    return attention_output
def memory_cells(self):
    # (batch_size x num_cells x memory_dim)
    mem_values = float_tensor_var([
        [[.1, .2, .3, .4],
         [.4, .5, .6, .7]],
        [[.2, .3, .4, .5],
         [.6, .7, .8, .9]],
        [[.3, .4, .5, .6],
         [.7, .8, .9, .1]],
        [[-8, -9, -10, -11],
         [-12, -13, -14, -15]],
        [[8, 9, 10, 11],
         [12, 13, 14, 15]],
    ])
    mem_mask = float_tensor_var([
        [1, 0],
        [1, 1],
        [1, 0],
        [0, 0],  # empty row
        [0, 1],  # right-justified
    ])
    memory_cells = SequenceBatch(values=mem_values, mask=mem_mask, left_justify=False)
    return memory_cells
def test_reduce_sum(self, some_seq_batch):
    result = SequenceBatch.reduce_sum(some_seq_batch)
    assert_tensor_equal(result, [
        [5, 7],
        [0, 4],
        [0, 0],
    ])
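# Sketch of a fixture consistent with the assertions in the reduce_* tests above
# (the repository's actual `some_seq_batch` fixture may use different padding values):
@pytest.fixture
def some_seq_batch(self):
    values = GPUVariable(torch.FloatTensor([
        [[1, 2], [4, 5], [7, 8]],
        [[0, 4], [3, 3], [0, 0]],
        [[6, 6], [5, 5], [4, 4]],
    ]))
    mask = GPUVariable(torch.FloatTensor([
        [1, 1, 0],
        [1, 0, 0],
        [0, 0, 0],  # empty sequence, so reduce_max / allow_empty=False raise ValueError
    ]))
    return SequenceBatch(values, mask)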
def _get_neighbors(self, web_page):
    """Get indices of at most |max_neighbors| neighbors for each relation.

    Args:
        web_page (WebPage)

    Returns:
        neighbors: SequenceBatch of shape num_nodes x ??? containing the
            neighbor refs (??? is at most max_neighbors * len(neighbor_rels))
        rels: SequenceBatch of shape num_nodes x ??? containing the relation indices
    """
    g = web_page.graph
    batch_neighbors = [[] for _ in range(len(web_page.nodes))]
    batch_rels = [[] for _ in range(len(web_page.nodes))]

    for src, tgts in g.nodes.items():
        # Group by relation
        rel_to_tgts = defaultdict(list)
        for tgt, rels in tgts.items():
            for rel in rels:
                rel_to_tgts[rel].append(tgt)
        # Sample if needed
        for rel, index in self._neighbor_rels.items():
            tgts = rel_to_tgts[rel]
            random.shuffle(tgts)
            if not tgts:
                continue
            if len(tgts) > self._max_neighbors:
                tgts = tgts[:self._max_neighbors]
            batch_neighbors[src].extend(tgts)
            batch_rels[src].extend([index] * len(tgts))

    # Create SequenceBatches
    max_len = max(len(x) for x in batch_neighbors)
    batch_mask = []
    for neighbors, rels in zip(batch_neighbors, batch_rels):
        assert len(neighbors) == len(rels)
        this_len = len(neighbors)
        batch_mask.append([1.] * this_len + [0.] * (max_len - this_len))
        neighbors.extend([0] * (max_len - this_len))
        rels.extend([0] * (max_len - this_len))

    return (SequenceBatch(
                V(torch.tensor(batch_neighbors, dtype=torch.long)),
                V(torch.tensor(batch_mask, dtype=torch.float32))),
            SequenceBatch(
                V(torch.tensor(batch_rels, dtype=torch.long)),
                V(torch.tensor(batch_mask, dtype=torch.float32))))
def split(self):
    """Split alignments object into per-time-step alignments.

    Returns:
        list[SequenceBatch]: where each element has shape (batch_size, max_alignments)
    """
    indices_list = [v.squeeze(1) for v in self.indices.split(1, dim=1)]
    mask_list = [v.squeeze(1) for v in self.mask.split(1, dim=1)]
    return [SequenceBatch(i, m) for i, m in zip(indices_list, mask_list)]
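# Shape illustration: if self.indices has shape (batch_size, seq_length, max_alignments),
# split() returns seq_length SequenceBatches, each with values of shape
# (batch_size, max_alignments), matching the docstring above.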
def _drop_seq_batch(self, seq_batch, word_vocab, keep_rate):
    """Randomly replace dropped tokens with UNK, keeping the start token intact."""
    batch_sz, max_seq_len = seq_batch.values.size()
    keep = (torch.rand(batch_sz, max_seq_len) < keep_rate).long()
    keep[:, 0] = 1  # do not drop start token
    kept = seq_batch.values * GPUVariable(keep)
    unkd = GPUVariable((1 - keep) * word_vocab.word2index(word_vocab.UNK))
    values = kept + unkd
    return SequenceBatch(values, seq_batch.mask)
def forward(self, encoder_output, train_decoder_input):
    """
    Args:
        encoder_output (EncoderOutput)
        train_decoder_input (TrainDecoderInput)

    Returns:
        rnn_states (list[RNNState])
        per_instance_losses (Variable): of shape (batch_size,), the loss for each
            example summed over time
    """
    batch_size, _ = train_decoder_input.input_words.mask.size()
    rnn_state = self.decoder_cell.initialize(batch_size)

    input_word_embeds = self.token_embedder.embed_seq_batch(
        train_decoder_input.input_words)
    input_embed_list = input_word_embeds.split()
    target_word_list = train_decoder_input.target_words.split()

    loss_list = []
    rnn_states = []
    for t, (x, target_word) in enumerate(izip(input_embed_list, target_word_list)):
        # x is a (batch_size, word_dim) SequenceBatchElement,
        # target_word is a (batch_size,) Variable

        # update rnn state
        rnn_input = self.rnn_context_combiner(encoder_output, x.values)
        decoder_cell_output = self.decoder_cell(rnn_state, rnn_input, x.mask)
        rnn_state = decoder_cell_output.rnn_state
        rnn_states.append(rnn_state)

        # compute loss
        loss = decoder_cell_output.loss(target_word.values)  # (batch_size,)
        loss_list.append(SequenceBatchElement(loss, x.mask))

    losses = SequenceBatch.cat(loss_list)  # (batch_size, target_seq_length)

    # sum losses across time, accounting for mask
    per_instance_losses = SequenceBatch.reduce_sum(losses)  # (batch_size,)
    return rnn_states, per_instance_losses
def embed_seq_batch(self, seq_batch):
    """Embed elements of a SequenceBatch.

    Args:
        seq_batch (SequenceBatch)

    Returns:
        SequenceBatch
    """
    return SequenceBatch(self._embedding(seq_batch.values), seq_batch.mask)