def test_reduce_mean(self, some_seq_batch):
    result = SequenceBatch.reduce_mean(some_seq_batch, allow_empty=True)
    assert_tensor_equal(result, [[2.5, 3.5], [0, 4], [0, 0]])

    with pytest.raises(ValueError):
        SequenceBatch.reduce_mean(some_seq_batch, allow_empty=False)
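# A minimal sketch of the masked mean that reduce_mean presumably computes
# (an assumption for illustration; the real implementation lives in
# SequenceBatch). It reproduces the expectations above: padded positions are
# zeroed out by the mask, and fully-masked rows yield zeros when
# allow_empty=True but raise a ValueError otherwise.
import torch

def masked_mean(values, mask, allow_empty=False):
    """values: (batch, seq_len, dim) floats; mask: (batch, seq_len) of 0/1."""
    counts = mask.sum(dim=1, keepdim=True)  # tokens per sequence, (batch, 1)
    if not allow_empty and (counts == 0).any():
        raise ValueError('Cannot take the mean of an empty sequence.')
    sums = (values * mask.unsqueeze(-1)).sum(dim=1)  # (batch, dim)
    return sums / counts.clamp(min=1)  # empty rows divide by 1 -> zeros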
def embed(self, sequences):
    for seq in sequences:
        if len(seq) == 0:
            raise ValueError("Cannot embed empty sequence.")

    token_indices = SequenceBatch.from_sequences(
        sequences, self.vocab, min_seq_length=1)

    # SequenceBatch of shape (batch_size, max_seq_length, word_dim)
    token_embeds = self.token_embedder.embed_seq_batch(token_indices)

    if self.pool == 'sum':
        pooled_token_embeds = SequenceBatch.reduce_sum(token_embeds)  # (batch_size, word_dim)
    elif self.pool == 'mean':
        pooled_token_embeds = SequenceBatch.reduce_mean(token_embeds)  # (batch_size, word_dim)
    elif self.pool == 'max':
        pooled_token_embeds = SequenceBatch.reduce_max(token_embeds)  # (batch_size, word_dim)
    else:
        raise ValueError('Unknown pooling strategy: {}'.format(self.pool))

    seq_embeds = self.transform(pooled_token_embeds)  # (batch_size, embed_dim)
    assert seq_embeds.size()[1] == self.embed_dim
    return seq_embeds
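# Toy illustration of the three pooling modes above on a padded batch, using
# plain tensor ops in place of the SequenceBatch reductions (the tensors and
# names below are local to this sketch, not part of the embedder's API):
import torch

token_embeds = torch.FloatTensor([[[1., 2.], [3., 4.]],   # two real tokens
                                  [[5., 6.], [9., 9.]]])  # one real, one pad
mask = torch.FloatTensor([[1, 1], [1, 0]])

summed = (token_embeds * mask.unsqueeze(-1)).sum(dim=1)    # pool == 'sum'
meaned = summed / mask.sum(dim=1, keepdim=True)            # pool == 'mean'
maxed = token_embeds.masked_fill(                          # pool == 'max'
    mask.unsqueeze(-1) == 0, float('-inf')).max(dim=1)[0]
# summed -> [[4, 6], [5, 6]]; meaned -> [[2, 3], [5, 6]]; maxed -> [[3, 4], [5, 6]]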
def forward(self, utterances):
    """Embeds a batch of utterances.

    Args:
        utterances (list[list[str]]): a batch of sentences, each a list
            of tokens.

    Returns:
        Tensor: batch x word_embed_dim (average of the word vectors)
    """
    # Cut to max_words, append EOS, and look up indices
    utterances = [
        utterance[:self._max_words] + [EOS] for utterance in utterances
    ]
    token_indices = SequenceBatch.from_sequences(
        utterances, self._token_embedder.vocab)

    # batch x seq_len x token_embed_dim
    token_embeds = self._token_embedder.embed_seq_batch(token_indices)

    # batch x token_embed_dim
    averaged = SequenceBatch.reduce_mean(token_embeds)
    return averaged
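# Runnable sketch of the truncate-and-terminate step above. EOS here is an
# assumed sentinel string; the real EOS constant is imported in the source.
EOS = '</s>'  # assumption: the actual sentinel comes from the vocab module
_max_words = 3
utterances = [['click', 'the', 'red', 'button'], ['scroll', 'down']]
truncated = [u[:_max_words] + [EOS] for u in utterances]
# -> [['click', 'the', 'red', '</s>'], ['scroll', 'down', '</s>']]
# Appending EOS guarantees every sequence is non-empty, so reduce_mean never
# sees an all-padding row.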
def forward(self, dom_elems, base_dom_embeds):
    """Embeds a batch of DOMElement sequences by combining base embeddings
    with embeddings of several notions of neighbors.

    Args:
        dom_elems (list[list[DOMElement]]): a batch of DOMElement
            sequences to embed. All sequences must be padded to have the
            same number of DOM elements.
        base_dom_embeds (Variable[FloatTensor]): of shape
            (batch_size, num_dom_elems, base_dom_embed_dim)

    Returns:
        dom_embeds (Variable[FloatTensor]): of shape
            (batch_size, num_dom_elems, embed_dim)
    """
    batch_size, num_dom_elems, embed_dim = base_dom_embeds.size()

    # flatten, for easier processing
    base_dom_embeds_flat = base_dom_embeds.view(
        batch_size * num_dom_elems, embed_dim)

    # list of length batch_size * num_dom_elems
    dom_elems_flat = flatten(dom_elems)
    assert len(dom_elems_flat) == batch_size * num_dom_elems

    # DOM neighbors whose LCA lies between depth 3 and 6 (root is depth 1)
    dom_neighbor_embeds = []
    for k in range(self._lca_depth_start, self._lca_depth_end):
        is_neighbor_fn = lambda elem1, elem2: (
            N.is_depth_k_lca_neighbor(elem1, elem2, k, self._lca_cache) and
            N.is_text_neighbor(elem1, elem2))
        dom_neighbor_indices = self._get_neighbor_indices(
            dom_elems, is_neighbor_fn)

        # TODO: reduce_max
        # (batch_size * num_dom_elems, embed_dim)
        neighbor_embedding = SequenceBatch.reduce_sum(
            SequenceBatch.embed(dom_neighbor_indices, base_dom_embeds_flat))

        # TODO: batch these projections, for performance
        projected_neighbor_embedding = self._dom_neighbor_projection(
            neighbor_embedding)
        dom_neighbor_embeds.append(projected_neighbor_embedding)

    # (batch_size * num_dom_elems, lca_range * (base_embed_dim / lca_range))
    dom_neighbor_embeds = torch.cat(dom_neighbor_embeds, 1)

    # SequenceBatch of shape (batch_size * num_dom_elems, max_neighbors)
    pixel_neighbor_indices = self._get_neighbor_indices(
        dom_elems,
        lambda elem1, elem2: (N.is_pixel_neighbor(elem1, elem2) and
                              N.is_text_neighbor(elem1, elem2)))

    # SequenceBatch of shape
    # (batch_size * num_dom_elems, max_neighbors, embed_dim)
    pixel_neighbor_embeds = SequenceBatch.embed(
        pixel_neighbor_indices, base_dom_embeds_flat)

    # TODO(kelvin): switch to reduce_max
    # (batch_size * num_dom_elems, embed_dim)
    pixel_neighbor_embeds_flat = SequenceBatch.reduce_mean(
        pixel_neighbor_embeds, allow_empty=True)

    dom_embeds_flat = torch.cat([
        base_dom_embeds_flat, pixel_neighbor_embeds_flat,
        dom_neighbor_embeds
    ], 1)  # (batch_size * num_dom_elems, embed_dim)

    dom_embeds = dom_embeds_flat.view(
        batch_size, num_dom_elems, self.embed_dim)
    return dom_embeds
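# A minimal sketch of what SequenceBatch.embed(indices, embeds) presumably
# does (an assumption for illustration): gather rows of the flat embedding
# matrix at each element's neighbor indices, yielding a padded batch of
# neighbor embeddings that the masked reductions above then pool.
import torch

embeds = torch.arange(12, dtype=torch.float).view(6, 2)  # 6 elems, dim 2
neighbor_indices = torch.LongTensor([[1, 3], [5, 0]])    # up to 2 neighbors each
neighbor_mask = torch.FloatTensor([[1, 1], [1, 0]])      # 2nd elem: 1 neighbor

gathered = embeds[neighbor_indices]  # (2, 2, 2): one row per (elem, neighbor)
pooled = (gathered * neighbor_mask.unsqueeze(-1)).sum(dim=1)  # like reduce_sum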