def test_embed(self):
    """Check that SequenceBatch.embed looks up rows and zero-pads correctly."""
    # Sequences of varying length, including an empty one.
    seqs = [
        [],
        [1, 2, 3],
        [3, 3],
        [2],
    ]
    vocabulary = SimpleVocab([0, 1, 2, 3, 4])
    idx_batch = SequenceBatch.from_sequences(seqs, vocabulary)

    # One 2-d embedding row per vocab entry.
    embedding_matrix = GPUVariable(torch.FloatTensor([
        [0, 0],    # 0
        [2, 2],    # 1
        [3, 4],    # 2
        [-10, 1],  # 3
        [11, -1],  # 4
    ]))

    result = SequenceBatch.embed(idx_batch, embedding_matrix)

    # Padded positions come back as zero vectors (see rows 0, 2, 3).
    expected = np.array(
        [
            [[0, 0], [0, 0], [0, 0]],
            [[2, 2], [3, 4], [-10, 1]],
            [[-10, 1], [-10, 1], [0, 0]],
            [[3, 4], [0, 0], [0, 0]],
        ],
        dtype=np.float32)
    assert_tensor_equal(result.values, expected)
def forward(self, dom_elems, base_dom_embeds):
    """Embeds a batch of DOMElement sequences by mixing base embeddings
    based on notions of neighbors.

    Args:
        dom_elems (list[list[DOMElement]]): a batch of DOMElement
            sequences to embed. All sequences must be padded to have the
            same number of DOM elements.
        base_dom_embeds (Variable[FloatTensor]): of shape
            (batch_size, num_dom_elems, base_dom_embed_dim)

    Returns:
        dom_embeds (Variable[FloatTensor]): of shape
            (batch_size, num_dom_elems, embed_dim)
    """
    batch_size, num_dom_elems, embed_dim = base_dom_embeds.size()

    # Flatten to (batch_size * num_dom_elems, embed_dim) for easier
    # processing.
    base_dom_embeds_flat = base_dom_embeds.view(
        batch_size * num_dom_elems, embed_dim)

    # list of length: batch_size * num_dom_elems
    dom_elems_flat = flatten(dom_elems)
    assert len(dom_elems_flat) == batch_size * num_dom_elems

    # DOM neighbors whose LCA depth ranges over
    # [self._lca_depth_start, self._lca_depth_end) (root is depth 1).
    dom_neighbor_embeds = []
    for k in xrange(self._lca_depth_start, self._lca_depth_end):
        # Bind k as a default argument so the closure captures this
        # iteration's value rather than the loop variable itself
        # (avoids the late-binding pitfall if the callable is ever
        # stored and invoked after the loop advances).
        is_neighbor_fn = lambda elem1, elem2, k=k: (
            N.is_depth_k_lca_neighbor(elem1, elem2, k, self._lca_cache)
            and N.is_text_neighbor(elem1, elem2))
        dom_neighbor_indices = self._get_neighbor_indices(
            dom_elems, is_neighbor_fn)

        # TODO: reduce_max
        # (batch_size * num_dom_elems, embed_dim)
        neighbor_embedding = SequenceBatch.reduce_sum(
            SequenceBatch.embed(dom_neighbor_indices, base_dom_embeds_flat))

        # TODO: Batch these projections? For performance
        projected_neighbor_embedding = self._dom_neighbor_projection(
            neighbor_embedding)
        dom_neighbor_embeds.append(projected_neighbor_embedding)

    # (batch_size * num_dom_elems,
    #  lca_range * (base_embed_dim / lca_range))
    # NOTE(review): torch.cat raises on an empty list, so this assumes
    # self._lca_depth_start < self._lca_depth_end — confirm in __init__.
    dom_neighbor_embeds = torch.cat(dom_neighbor_embeds, 1)

    # SequenceBatch of shape (batch_size * num_dom_elems, max_neighbors)
    pixel_neighbor_indices = self._get_neighbor_indices(
        dom_elems,
        lambda elem1, elem2: (N.is_pixel_neighbor(elem1, elem2)
                              and N.is_text_neighbor(elem1, elem2)))

    # SequenceBatch of shape
    # (batch_size * num_dom_elems, max_neighbors, embed_dim)
    pixel_neighbor_embeds = SequenceBatch.embed(
        pixel_neighbor_indices, base_dom_embeds_flat)

    # TODO(kelvin): switch to reduce_max
    # (batch_size * num_dom_elems, embed_dim)
    pixel_neighbor_embeds_flat = SequenceBatch.reduce_mean(
        pixel_neighbor_embeds, allow_empty=True)

    # Concatenate base, pixel-neighbor, and DOM-neighbor features along
    # the embedding dimension.
    dom_embeds_flat = torch.cat([
        base_dom_embeds_flat,
        pixel_neighbor_embeds_flat,
        dom_neighbor_embeds,
    ], 1)

    dom_embeds = dom_embeds_flat.view(
        batch_size, num_dom_elems, self.embed_dim)
    return dom_embeds