Example #1
    def forward(self, insert_embeds, insert_embeds_exact, delete_embeds, delete_embeds_exact, draw_samples=False, draw_p=False):
        """Create agenda vector.

        Args:
            insert_embeds (SequenceBatch): of shape (batch_size, max_edits, word_dim)
            insert_embeds_exact (SequenceBatch): of shape (batch_size, max_edits, word_dim)
            delete_embeds (SequenceBatch): of shape (batch_size, max_edits, word_dim)
            delete_embeds_exact (SequenceBatch): of shape (batch_size, max_edits, word_dim)
            draw_samples (bool): whether to add noise for the variational
                approximation; disable at test time.
            draw_p (bool): if True, replace the edit vector entirely with
                noise drawn via draw_p_noise.

        Returns:
            edit_embed (Variable): of shape (batch_size, edit_vec_dim)
        """
        insert_embed = SequenceBatch.reduce_sum(insert_embeds)  # (batch_size, word_dim)
        insert_embed += SequenceBatch.reduce_sum(insert_embeds_exact)  # (batch_size, word_dim)
        delete_embed = SequenceBatch.reduce_sum(delete_embeds)  # (batch_size, word_dim)
        delete_embed += SequenceBatch.reduce_sum(delete_embeds_exact)  # (batch_size, word_dim)
        insert_set = self.linear_prenoise(insert_embed)
        delete_set = self.linear_prenoise(delete_embed)
        combined_map = torch.cat([insert_set, delete_set], 1)
        if draw_samples:
            if draw_p:
                batch_size, edit_dim = combined_map.size()
                combined_map = self.draw_p_noise(batch_size, edit_dim)
            else:
                combined_map = self.sample_vMF(combined_map, self.noise_scaler)
        edit_embed = combined_map
        return edit_embed
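Ignoring the noise branch, the construction above reduces to two shared linear projections followed by a concatenation. A minimal standalone sketch, assuming linear_prenoise is a plain nn.Linear mapping word_dim to half the edit dimension (the actual module and sizes may differ):

import torch
import torch.nn as nn

# Hypothetical dimensions; linear_prenoise projects each summed word
# embedding to half the edit dimension so the insert and delete halves
# concatenate back to a full edit vector.
word_dim, edit_dim = 8, 6
linear_prenoise = nn.Linear(word_dim, edit_dim // 2, bias=False)

insert_embed = torch.randn(4, word_dim)  # summed insert-word embeddings
delete_embed = torch.randn(4, word_dim)  # summed delete-word embeddings

insert_set = linear_prenoise(insert_embed)
delete_set = linear_prenoise(delete_embed)
combined_map = torch.cat([insert_set, delete_set], 1)
print(combined_map.size())  # torch.Size([4, 6])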
Example #2
    def test_reduce_sum(self, some_seq_batch):
        result = SequenceBatch.reduce_sum(some_seq_batch)

        assert_tensor_equal(result, [
            [5, 7],
            [0, 4],
            [0, 0],
        ])
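One input that would satisfy this assertion, assuming reduce_sum is a mask-weighted sum over the sequence dimension (a plain-tensor sketch, not the library implementation):

import torch

# values: (batch_size=3, seq_len=2, dim=2); mask marks real positions.
values = torch.tensor([[[1., 2.], [4., 5.]],
                       [[0., 4.], [9., 9.]],
                       [[7., 7.], [8., 8.]]])
mask = torch.tensor([[1., 1.],
                     [1., 0.],
                     [0., 0.]])

result = (values * mask.unsqueeze(-1)).sum(dim=1)
print(result)
# tensor([[5., 7.],
#         [0., 4.],
#         [0., 0.]])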
Example #3
    def loss(self, encoder_output, train_decoder_input):
        _, _, losses = self(encoder_output, train_decoder_input)

        # sum losses across time, accounting for mask
        per_instance_losses = SequenceBatch.reduce_sum(losses)  # (batch_size,)

        # average across instances
        total_loss = torch.mean(per_instance_losses)

        return total_loss, losses
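A toy walk-through of the same reduction, assuming padded time steps already carry zero loss:

import torch

losses = torch.tensor([[0.5, 0.3, 0.0],
                       [0.2, 0.0, 0.0]])   # (batch_size, max_time), zeros at padding
per_instance_losses = losses.sum(dim=1)    # tensor([0.8000, 0.2000])
total_loss = per_instance_losses.mean()    # tensor(0.5000)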
Example #4
    def forward(self, dom_elements, alignment_fields):
        """Computes the alignments. An element aligns iff elem.text
        in utterance and elem.text != ""

        Args:
            dom_elements (list[list[DOMElement]]): batch of set of DOM
                elements (padded to be unragged)
            alignment_fields (list[Fields]): batch of fields. Alignments
                computed with the values of the fields.

        Returns:
            Variable[FloatTensor]: batch x num_elems x embed_dim
                The aligned embeddings per DOM element
        """
        batch_size = len(dom_elements)
        assert batch_size > 0
        num_dom_elems = len(dom_elements[0])
        assert num_dom_elems > 0

        # mask batch_size x num_dom_elems x num_buckets
        alignments = np.zeros(
            (batch_size, num_dom_elems, self._num_buckets)).astype(np.float32)

        # Calculate the alignment matrix between elems and fields
        for batch_idx in range(len(dom_elements)):
            for dom_idx, dom in enumerate(dom_elements[batch_idx]):
                keys = alignment_fields[batch_idx].keys
                vals = alignment_fields[batch_idx].values
                for key, val in zip(keys, vals):
                    if dom.text and dom.text in val:
                        align_idx = self._keys2index.word2index(key)
                        alignments[batch_idx, dom_idx, align_idx] = 1.

        # Flatten alignments for SequenceBatch
        # (batch * num_dom_elems) x num_buckets
        alignments = GPUVariable(
            torch.from_numpy(
                alignments.reshape(
                    (batch_size * num_dom_elems, self._num_buckets))))

        # (batch * num_dom_elems) x num_buckets x embed_dim
        expanded_alignment_embeds = self._alignment_embeds.expand(
            batch_size * num_dom_elems, self._num_buckets, self.embed_dim)
        alignment_seq_batch = SequenceBatch(expanded_alignment_embeds,
                                            alignments,
                                            left_justify=False)

        # (batch * num_dom_elems) x alignment_embed_dim
        alignment_embeds = SequenceBatch.reduce_sum(alignment_seq_batch)
        return alignment_embeds.view(batch_size, num_dom_elems, self.embed_dim)
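In plain tensor operations, the embed-then-reduce step above amounts to summing, for each (batch, dom_elem) row, the bucket embeddings whose alignment bit is set. A sketch with hypothetical shapes:

import torch

num_rows, num_buckets, embed_dim = 6, 4, 5   # num_rows = batch * num_dom_elems
alignment_embeds = torch.randn(num_buckets, embed_dim)
alignments = torch.randint(0, 2, (num_rows, num_buckets)).float()

# Broadcast the 0/1 alignments over embed_dim and sum across buckets.
aligned = (alignment_embeds.unsqueeze(0) * alignments.unsqueeze(-1)).sum(dim=1)
print(aligned.size())  # torch.Size([6, 5])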
Example #5
    def forward(self, encoder_output, train_decoder_input):
        """

        Args:
            encoder_output (EncoderOutput)
            train_decoder_input (TrainDecoderInput)

        Returns:
            rnn_states (list[RNNState])
            per_instance_losses (Variable): of shape (batch_size,), the
                masked sum of losses over time for each instance
        """
        batch_size, _ = train_decoder_input.input_words.mask.size()
        rnn_state = self.decoder_cell.initialize(batch_size)

        input_word_embeds = self.token_embedder.embed_seq_batch(
            train_decoder_input.input_words)

        input_embed_list = input_word_embeds.split()
        target_word_list = train_decoder_input.target_words.split()

        loss_list = []
        rnn_states = []
        for t, (x, target_word) in enumerate(
                zip(input_embed_list, target_word_list)):
            # x and target_word are SequenceBatchElements; x.values is
            # (batch_size, word_dim), target_word.values is (batch_size,)

            # update rnn state
            rnn_input = self.rnn_context_combiner(encoder_output, x.values)
            decoder_cell_output = self.decoder_cell(rnn_state, rnn_input,
                                                    x.mask)
            rnn_state = decoder_cell_output.rnn_state
            rnn_states.append(rnn_state)

            # compute loss
            loss = decoder_cell_output.loss(
                target_word.values)  # (batch_size,)
            loss_list.append(SequenceBatchElement(loss, x.mask))

        losses = SequenceBatch.cat(
            loss_list)  # (batch_size, target_seq_length)

        # sum losses across time, accounting for mask
        per_instance_losses = SequenceBatch.reduce_sum(losses)  # (batch_size,)
        return rnn_states, per_instance_losses
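The cat-then-reduce pattern can be mimicked with stock tensor ops: stack the per-step losses into (batch_size, time), apply the mask, and sum over time. A toy sketch:

import torch

step_losses = [torch.tensor([0.4, 0.1]), torch.tensor([0.2, 0.9])]
step_masks = [torch.tensor([1., 1.]), torch.tensor([1., 0.])]

losses = torch.stack(step_losses, dim=1)          # (batch_size, time)
mask = torch.stack(step_masks, dim=1)             # (batch_size, time)
per_instance_losses = (losses * mask).sum(dim=1)  # tensor([0.6000, 0.1000])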
Example #6
    def embed(self, sequences):
        for seq in sequences:
            if len(seq) == 0:
                raise ValueError("Cannot embed empty sequence.")

        token_indices = SequenceBatch.from_sequences(sequences, self.vocab, min_seq_length=1)
        token_embeds = self.token_embedder.embed_seq_batch(token_indices)  # SequenceBatch of size (batch_size, max_seq_length, word_dim)
        if self.pool == 'sum':
            pooled_token_embeds = SequenceBatch.reduce_sum(token_embeds)  # (batch_size, word_dim)
        elif self.pool == 'mean':
            pooled_token_embeds = SequenceBatch.reduce_mean(token_embeds)  # (batch_size, word_dim)
        elif self.pool == 'max':
            pooled_token_embeds = SequenceBatch.reduce_max(token_embeds)  # (batch_size, word_dim)
        else:
            raise ValueError("Unsupported pool type: {}".format(self.pool))

        seq_embeds = self.transform(pooled_token_embeds)  # (batch_size, embed_dim)
        assert seq_embeds.size()[1] == self.embed_dim

        return seq_embeds
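For reference, the three pooling branches correspond roughly to the following masked reductions (a sketch assuming a 0/1 padding mask; the library versions may treat edge cases such as empty sequences differently):

import torch

values = torch.randn(2, 3, 4)         # (batch_size, max_seq_length, word_dim)
mask = torch.tensor([[1., 1., 0.],
                     [1., 0., 0.]])   # 1 = real token, 0 = padding
m = mask.unsqueeze(-1)

pooled_sum = (values * m).sum(dim=1)                                      # 'sum'
pooled_mean = pooled_sum / mask.sum(dim=1, keepdim=True)                  # 'mean'
pooled_max = values.masked_fill(m == 0, float('-inf')).max(dim=1).values  # 'max'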
Example #7
    def forward(self, dom_elems, base_dom_embeds):
        """Embeds a batch of DOMElement sequences by mixing base embeddings on
        notions of neighbors.

        Args:
            dom_elems (list[list[DOMElement]]): a batch of DOMElement
                sequences to embed. All sequences must be padded to have the
                same number of DOM elements.
            base_dom_embeds (Variable[FloatTensor]):
                batch_size, num_dom_elems, base_dom_embed_dim

        Returns:
            dom_embeds (Variable[FloatTensor]): of shape (batch_size,
                num_dom_elems, embed_dim)
        """
        batch_size, num_dom_elems, embed_dim = base_dom_embeds.size()

        # flatten, for easier processing
        base_dom_embeds_flat = base_dom_embeds.view(batch_size * num_dom_elems,
                                                    embed_dim)

        # list of length: batch_size * num_dom_elems
        dom_elems_flat = flatten(dom_elems)
        assert len(dom_elems_flat) == batch_size * num_dom_elems

        # DOM neighbors whose LCA goes from depth 3 to 6 (root is depth 1)
        dom_neighbor_embeds = []
        for k in range(self._lca_depth_start, self._lca_depth_end):
            is_neighbor_fn = lambda elem1, elem2: (N.is_depth_k_lca_neighbor(
                elem1, elem2, k, self._lca_cache) and N.is_text_neighbor(
                    elem1, elem2))
            dom_neighbor_indices = self._get_neighbor_indices(
                dom_elems, is_neighbor_fn)

            # TODO: reduce_max
            # (batch_size * num_dom_elems, embed_dim)
            neighbor_embedding = SequenceBatch.reduce_sum(
                SequenceBatch.embed(dom_neighbor_indices,
                                    base_dom_embeds_flat))
            # TODO: Batch these projections? For performance
            projected_neighbor_embedding = self._dom_neighbor_projection(
                neighbor_embedding)
            dom_neighbor_embeds.append(projected_neighbor_embedding)

        # (batch_size * num_dom_elems, lca_range * (base_embed_dim /
        # lca_range))
        dom_neighbor_embeds = torch.cat(dom_neighbor_embeds, 1)

        # SequenceBatch of shape (batch_size * num_dom_elems, max_neighbors)
        pixel_neighbor_indices = self._get_neighbor_indices(
            dom_elems, lambda elem1, elem2:
            (N.is_pixel_neighbor(elem1, elem2) and N.is_text_neighbor(
                elem1, elem2)))

        # SequenceBatch of shape
        # (batch_size * num_dom_elems, max_neighbors, embed_dim)
        pixel_neighbor_embeds = SequenceBatch.embed(pixel_neighbor_indices,
                                                    base_dom_embeds_flat)

        # TODO(kelvin): switch to reduce_max
        # (batch_size * num_dom_elems, embed_dim)
        pixel_neighbor_embeds_flat = SequenceBatch.reduce_mean(
            pixel_neighbor_embeds, allow_empty=True)

        dom_embeds_flat = torch.cat([
            base_dom_embeds_flat, pixel_neighbor_embeds_flat,
            dom_neighbor_embeds
        ], 1)
        dom_embeds = dom_embeds_flat.view(batch_size, num_dom_elems,
                                          self.embed_dim)

        return dom_embeds
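The embed-and-reduce step over neighbor indices can be sketched with plain indexing and a masked sum (indices, mask, and shapes below are made up):

import torch

flat_embeds = torch.randn(6, 5)                        # (rows, embed_dim)
neighbor_idx = torch.tensor([[1, 2], [0, 0], [3, 4]])  # (rows', max_neighbors)
neighbor_mask = torch.tensor([[1., 1.],
                              [1., 0.],
                              [0., 0.]])               # 0 marks padded slots

gathered = flat_embeds[neighbor_idx]  # (rows', max_neighbors, embed_dim)
neighbor_sum = (gathered * neighbor_mask.unsqueeze(-1)).sum(dim=1)
print(neighbor_sum.size())  # torch.Size([3, 5])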