Example #1
    def train_step(self, examples):
        # sample a beam of logical forms for each example
        beams = self.predictions(examples, train=True)

        all_cases = []  # a list of ParseCases to give to ParseModel
        all_case_weights = [] # the weights associated with the cases
        for example, paths in zip(examples, beams):
            case_weights = self._case_weighter(paths, example)
            case_weights = flatten(case_weights)
            cases = flatten(paths)
            assert len(case_weights) == sum(len(p) for p in paths)

            all_cases.extend(cases)
            all_case_weights.extend(case_weights)

        # for efficiency, prune cases with weight 0
        cases_to_reinforce = []
        weights_to_reinforce = []
        for case, weight in zip(all_cases, all_case_weights):
            if weight != 0:
                cases_to_reinforce.append(case)
                weights_to_reinforce.append(weight)

        # update value function
        vf_examples = []
        for example, paths in zip(examples, beams):
            vf_examples.extend(
                ValueFunctionExample.examples_from_paths(paths, example))
        self._value_function.train_step(vf_examples)

        # update parse model
        self._parse_model.train_step(cases_to_reinforce,
                                     weights_to_reinforce,
                                     caching=False)
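Every example on this page relies on a small flatten utility that concatenates a list of lists into a single list. A minimal sketch, consistent with the inline lambda defined in Example #3 below (the actual helper lives elsewhere in the codebase):

def flatten(list_of_lists):
    """Concatenate a list of lists into one flat list."""
    return [item for sublist in list_of_lists for item in sublist]

# flatten([[1, 2], [3]]) == [1, 2, 3]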
Example #2
    def _query_embeds(self, states, query_entries):
        """Given a batch of states, embed the keys and values of each state's
        query.

        Args:
            states (list[MiniWoBState])

        Returns:
            entry_embeds (SequenceBatch): batch x num_keys x (2 * embed_dim)
                the keys and values concatenated
        """
        fields_batch = [state.fields for state in states]

        # list[list[list[unicode]]] (batch x num_keys x key length)
        values_batch = [[word_tokenize(value) for value in fields.values] for
                        fields in fields_batch]
        keys_batch = [[word_tokenize(key) for key in fields.keys] for fields
                      in fields_batch]

        # Pad
        batch_size = len(fields_batch)
        max_num_fields = max(len(values) for values in values_batch)
        max_num_fields = max(max_num_fields, 1)  # Ensure non-empty
        mask = torch.ones(batch_size, max_num_fields)
        assert len(keys_batch) == len(values_batch) == len(mask)
        for keys, values, submask in zip(keys_batch, values_batch, mask):
            assert len(keys) == len(values)
            if len(keys) < max_num_fields:
                submask[len(keys):] = 0.
                keys.extend(
                    [[UtteranceVocab.PAD] for _ in xrange(
                        max_num_fields - len(keys))])
                values.extend(
                    [[UtteranceVocab.PAD] for _ in xrange(
                        max_num_fields - len(values))])

        # Flatten to list[list[unicode]] (batch * num_keys) x key length
        keys_batch = flatten(keys_batch)
        values_batch = flatten(values_batch)

        # Embed and mask (batch * num_keys) x embed_dim
        key_embeds, _ = self._utterance_embedder(keys_batch)
        key_embeds = key_embeds.view(
                batch_size, max_num_fields, self._utterance_embedder.embed_dim)
        value_embeds, _ = self._utterance_embedder(values_batch)
        value_embeds = value_embeds.view(
                batch_size, max_num_fields, self._utterance_embedder.embed_dim)
        key_embeds = SequenceBatch(key_embeds, GPUVariable(mask))
        value_embeds = SequenceBatch(value_embeds, GPUVariable(mask))

        entry_embed_values = torch.cat(
                [key_embeds.values, value_embeds.values], 2)
        entry_embeds = SequenceBatch(entry_embed_values, key_embeds.mask)
        return entry_embeds
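The padding step above can be seen in isolation on a toy batch. A sketch assuming a stand-in PAD token (UtteranceVocab.PAD in the real code) and a mask that marks real fields with 1:

import torch

PAD = u'<pad>'  # stand-in for UtteranceVocab.PAD

keys_batch = [[[u'name'], [u'color']], [[u'date']]]    # 2 examples with 2 and 1 fields
values_batch = [[[u'bob'], [u'red']], [[u'monday']]]

max_num_fields = max(len(keys) for keys in keys_batch)  # 2
mask = torch.ones(len(keys_batch), max_num_fields)
for keys, values, submask in zip(keys_batch, values_batch, mask):
    if len(keys) < max_num_fields:
        submask[len(keys):] = 0.
        keys.extend([[PAD] for _ in range(max_num_fields - len(keys))])
        values.extend([[PAD] for _ in range(max_num_fields - len(values))])
# mask is now [[1, 1], [1, 0]] and every example has exactly 2 (possibly padded) fields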
Example #3
    def __init__(self, hard_copy_vocab, source_tokens, token_lengths):
        """HardCopyDynamicVocab.
        
        NOTE: HardCopyDynamicVocab is blind to casing.
        
        Args:
            hard_copy_vocab (HardCopyVocab)
            source_tokens (list[list[unicode]]): outer list is over input channels,
                each inner list is a tokenized sentence
            token_lengths (list[int]): number of copy tokens to allocate to each channel
        """
        num_channels = len(source_tokens)

        tok_cum = [0] + np.cumsum(token_lengths).tolist()
        tokens_by_channel = [hard_copy_vocab.copy_tokens[tok_cum[i]:tok_cum[i + 1]]
                             for i in range(num_channels)]
        self.tokens_by_channel = tokens_by_channel

        flatten = lambda l: [item for sublist in l for item in sublist]

        copy_words = [sentence_tokens + [hard_copy_vocab.STOP]
                      for sentence_tokens in source_tokens]

        # pair each copy word with a copy token, channel by channel
        # note that zip stops at the shorter of the two lists (desired behavior)
        copy_pairs = flatten([zip(copy_words[i], tokens_by_channel[i])
                              for i in range(num_channels)])

        # iterate in reverse so that, for a word appearing more than once,
        # its earliest occurrence determines the copy token
        self.word_to_copy_token = {w: t for w, t in copy_pairs[::-1]}
        self.copy_token_to_word = {t: w for w, t in copy_pairs[::-1]}

        # keep a reference to the base vocab (which maps words to indices)
        self.base_vocab = hard_copy_vocab
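A toy walkthrough of the word-to-copy-token pairing, using illustrative token spellings rather than the library's actual copy tokens:

copy_tokens = [u'<copy_0>', u'<copy_1>', u'<copy_2>', u'<copy_3>', u'<copy_4>']
STOP = u'</s>'
source_tokens = [[u'turn', u'left'], [u'go']]   # two input channels
token_lengths = [3, 2]                          # copy tokens reserved per channel

tokens_by_channel = [copy_tokens[0:3], copy_tokens[3:5]]
copy_words = [tokens + [STOP] for tokens in source_tokens]
copy_pairs = [pair for i in range(2)
              for pair in zip(copy_words[i], tokens_by_channel[i])]
word_to_copy_token = dict(copy_pairs[::-1])
# {u'turn': u'<copy_0>', u'left': u'<copy_1>', u'</s>': u'<copy_2>', u'go': u'<copy_3>'}
# the [::-1] makes the earliest occurrence of a repeated word win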
Example #4
    def __init__(self, whitelists, vocab, word_to_forms, always_allowed=None):
        """Modify extension probs such that only words appearing in the whitelist (or have a variant in the whitelist) are allowed.
        
        Args:
            whitelists (list[list[unicode]]): a batch of whitelists, one per example.
                Each whitelist is a list of permitted words.
            vocab (HardCopyVocab)
            word_to_forms (Callable[unicode, list[unicode]): a function mapping words to forms
            always_allowed (list[unicode]): a list of words that is allowed in any example
        """
        # importantly, vocab should NOT be a HardCopyDynamicVocab
        assert isinstance(vocab, HardCopyVocab)

        if always_allowed is None:
            # always allow STOP, copy tokens, and various stop words
            always_allowed = []
            always_allowed.append(vocab.STOP)
            always_allowed.extend(vocab.copy_tokens)
            always_allowed.extend(STOPWORDS)  # all lower case

        # so far, these should all be unique
        always_allowed_indices = [vocab.word2index(w) for w in always_allowed if w in vocab]

        # precompute actual allowed words
        all_whitelist_words = set(flatten(whitelists))
        word_to_form_indices = self._word_to_form_indices(all_whitelist_words, vocab, word_to_forms)

        allowed_indices = []
        for whitelist in whitelists:
            indices = self._whitelist_to_indices(whitelist, word_to_form_indices, always_allowed_indices)
            allowed_indices.append(indices)
            # Returned indices may contain some duplicates.

        self.allowed_indices = allowed_indices
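A sketch of a word_to_forms callable one might pass in; the crude pluralization below is purely illustrative and not what the project actually uses:

def word_to_forms(word):
    """Map a word to a small set of surface forms (illustrative only)."""
    word = word.lower()
    forms = {word}
    if word.endswith(u's'):
        forms.add(word[:-1])     # naive singularization
    else:
        forms.add(word + u's')   # naive pluralization
    return list(forms)

# word_to_forms(u'Cats') -> [u'cats', u'cat'] (in some order)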
Example #5
    def _get_neighbor_indices(self, dom_elements, is_neighbor):
        """Compute neighbor indices.

        Args:
            dom_elements (list[DOMElement]): may include PAD elements
            is_neighbor (Callable: DOMElement x DOMElement --> bool): True if
                two DOM elements are neighbors of each other, otherwise False

        Returns:
            SequenceBatch: of shape (total_dom_elems, max_neighbors)
        """
        dom_element_ids = [id(e) for e in flatten(dom_elements)]
        dom_element_ids_set = set(dom_element_ids)
        vocab = SuperSimpleVocab(dom_element_ids)

        neighbors_batch = []
        for dom_batch in dom_elements:
            for dom_elem in dom_batch:
                # Optimization: no DOM PAD has neighbors
                if isinstance(dom_elem, DOMElementPAD):
                    neighbors = []
                else:
                    neighbors = []
                    for neighbor in dom_batch:
                        if is_neighbor(dom_elem, neighbor):
                            neighbors.append(id(neighbor))

                neighbors_batch.append(neighbors)

        neighbor_indices = SequenceBatch.from_sequences(neighbors_batch,
                                                        vocab,
                                                        min_seq_length=1)
        return neighbor_indices
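Example #11 below shows how this helper is called in practice, with a predicate built from the N neighbor utilities:

pixel_neighbor_indices = self._get_neighbor_indices(
    dom_elems,
    lambda elem1, elem2: (N.is_pixel_neighbor(elem1, elem2) and
                          N.is_text_neighbor(elem1, elem2)))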
Example #6
def decoder_inputs_and_outputs(target_words, base_vocab):
    """Convert a sequence of tokens into a decoder input seq and output seq.
    
    Args:
        target_words (list[list[unicode]])
        base_vocab (Vocab)

    Returns:
        input_words (list[unicode])
        output_words (list[unicode])
    """
    # prepend the <start> token and flatten the channels
    input_words = [base_vocab.START] + flatten(target_words)
    # append the <stop> token
    output_words = flatten(target_words) + [base_vocab.STOP]
    return input_words, output_words
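A worked example, assuming base_vocab.START and base_vocab.STOP render as <start> and <stop> (the token spellings are illustrative):

target_words = [[u'the', u'cat'], [u'sat']]   # two channels of target tokens
input_words, output_words = decoder_inputs_and_outputs(target_words, base_vocab)
# input_words  == [u'<start>', u'the', u'cat', u'sat']
# output_words == [u'the', u'cat', u'sat', u'<stop>']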
Example #7
    def _edit_batch(self, examples, max_seq_length, beam_size, constrain_vocab):
        # should only run in evaluation mode
        assert not self.training

        input_words, output_words = self._batch_editor_examples(examples)
        base_vocab = self.base_vocab
        dynamic_vocabs = self._compute_dynamic_vocabs(input_words, base_vocab)
        dynamic_token_embedder = DynamicMultiVocabTokenEmbedder(
            self.base_source_token_embedder, dynamic_vocabs, base_vocab)

        encoder_input = self.encoder.preprocess(
            input_words, output_words, dynamic_token_embedder, volatile=True)
        encoder_output, _ = self.encoder(encoder_input)

        extension_probs_modifiers = []

        if constrain_vocab:
            whitelists = [flatten(ex.input_words) for ex in examples]  # will contain duplicates, that's ok
            vocab_constrainer = LexicalWhitelister(whitelists, self.base_vocab, word_to_forms)
            extension_probs_modifiers.append(vocab_constrainer)

        beams, decoder_traces = self.test_decoder_beam.decode(
            examples, encoder_output,
            beam_size=beam_size, max_seq_length=max_seq_length,
            extension_probs_modifiers=extension_probs_modifiers)

        # replace copy tokens in predictions with actual words, modifying beams in-place
        for beam, dyna_vocab in izip(beams, dynamic_vocabs):
            copy_to_word = dyna_vocab.copy_token_to_word
            for i, seq in enumerate(beam):
                beam[i] = [copy_to_word.get(w, w) for w in seq]

        return beams, [EditTrace(ex, d_trace.beam_traces[-1], dyna_vocab)
                       for ex, d_trace, dyna_vocab in izip(examples, decoder_traces, dynamic_vocabs)]
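The final loop maps copy tokens in each predicted sequence back to source words via the dynamic vocab's copy_token_to_word dict (built in Example #3). A toy version of that substitution, with illustrative token names:

copy_to_word = {u'<copy_0>': u'turn', u'<copy_1>': u'left'}
seq = [u'<copy_0>', u'slightly', u'<copy_1>']
restored = [copy_to_word.get(w, w) for w in seq]
# restored == [u'turn', u'slightly', u'left']; tokens without an entry pass through unchanged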
Example #8
    def update_from_demonstrations(self, demonstrations, take_grad_step):
        """Calculates the cross-entropy loss from a batch of demonstrations.

        Args:
            demonstrations (EpisodeGraph)

        Returns:
            loss (Variable[FloatTensor])
            take_grad_step (Callable): takes a loss Variable and takes a
                gradient step on the loss
        """
        experiences = flatten(episode_graph.to_experiences()
                for episode_graph in demonstrations)
        scored_experiences = self._score_experiences(experiences)
        loss = -sum(exp.log_prob for exp in scored_experiences) / len(demonstrations)
        self._clear_cache()
        take_grad_step(loss)
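take_grad_step is simply a callable that applies an optimizer update to the given loss. A minimal sketch assuming a standard PyTorch optimizer (an assumption, not the project's actual helper):

def make_take_grad_step(optimizer):
    def take_grad_step(loss):
        optimizer.zero_grad()   # clear gradients from the previous step
        loss.backward()         # backprop through the scored experiences
        optimizer.step()        # apply the parameter update
    return take_grad_step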
Example #9
    def _score_episodes(self, episodes):
        """Score all the experiences in each of the episodes.

        Args:
            episodes (list[Episode])

        Returns:
            scored_episodes (list[ScoredEpisode])
        """
        # score all experiences in a batch
        experiences = flatten(episodes)
        scored_experiences = self._score_experiences(experiences)

        # convert experiences back into episodes
        scored_episodes = []
        scored_experiences_reversed = list(reversed(scored_experiences))
        for ep in episodes:
            scored_ep = Episode()
            for _ in range(len(ep)):
                scored_ep.append(scored_experiences_reversed.pop())
            scored_episodes.append(scored_ep)

        return scored_episodes
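The flatten-then-regroup bookkeeping above can be illustrated with plain lists; only the episode lengths are needed to undo the flattening (a slice-based regrouping is shown here instead of the reversed-pop used above):

episodes = [[u'exp_a', u'exp_b'], [u'exp_c']]   # two episodes of lengths 2 and 1
flat = flatten(episodes)                        # [u'exp_a', u'exp_b', u'exp_c']
scored = [exp.upper() for exp in flat]          # stand-in for _score_experiences
regrouped, start = [], 0
for ep in episodes:
    regrouped.append(scored[start:start + len(ep)])
    start += len(ep)
# regrouped == [[u'EXP_A', u'EXP_B'], [u'EXP_C']]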
Example #10
    def beam_duplicate(self, beam_size):
        duplicated_dynamic_vocabs = flatten(
            [dyna_vocab] * beam_size for dyna_vocab in self.dynamic_vocabs)
        return DynamicMultiVocabTokenEmbedder(
            self.base_embedder, duplicated_dynamic_vocabs, self.base_vocab)
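A small worked illustration of the duplication order: each example's dynamic vocab is repeated beam_size times consecutively, grouped per example rather than interleaved:

dynamic_vocabs = [u'vocab_1', u'vocab_2']   # stand-ins for HardCopyDynamicVocab objects
beam_size = 3
duplicated = [v for vocab in dynamic_vocabs for v in [vocab] * beam_size]
# duplicated == [u'vocab_1', u'vocab_1', u'vocab_1', u'vocab_2', u'vocab_2', u'vocab_2']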
Example #11
    def forward(self, dom_elems, base_dom_embeds):
        """Embeds a batch of DOMElement sequences by mixing base embeddings on
        notions of neighbors.

        Args:
            dom_elems (list[list[DOMElement]]): a batch of DOMElement
                sequences to embed. All sequences must be padded to have the
                same number of DOM elements.
            base_dom_embeds (Variable[FloatTensor]):
                batch_size, num_dom_elems, base_dom_embed_dim

        Returns:
            dom_embeds (Variable[FloatTensor]): of shape (batch_size,
                num_dom_elems, embed_dim)
        """
        batch_size, num_dom_elems, embed_dim = base_dom_embeds.size()

        # flatten, for easier processing
        base_dom_embeds_flat = base_dom_embeds.view(batch_size * num_dom_elems,
                                                    embed_dim)

        # list of length: batch_size * num_dom_elems
        dom_elems_flat = flatten(dom_elems)
        assert len(dom_elems_flat) == batch_size * num_dom_elems

        # DOM neighbors whose LCA goes from depth 3 to 6 (root is depth 1)
        dom_neighbor_embeds = []
        for k in xrange(self._lca_depth_start, self._lca_depth_end):
            is_neighbor_fn = lambda elem1, elem2: (
                N.is_depth_k_lca_neighbor(elem1, elem2, k, self._lca_cache) and
                N.is_text_neighbor(elem1, elem2))
            dom_neighbor_indices = self._get_neighbor_indices(
                dom_elems, is_neighbor_fn)

            # TODO: reduce_max
            # (batch_size * num_dom_elems, embed_dim)
            neighbor_embedding = SequenceBatch.reduce_sum(
                SequenceBatch.embed(dom_neighbor_indices,
                                    base_dom_embeds_flat))
            # TODO: Batch these projections? For performance
            projected_neighbor_embedding = self._dom_neighbor_projection(
                neighbor_embedding)
            dom_neighbor_embeds.append(projected_neighbor_embedding)

        # (batch_size * num_dom_elems, lca_range * (base_embed_dim /
        # lca_range))
        dom_neighbor_embeds = torch.cat(dom_neighbor_embeds, 1)

        # SequenceBatch of shape (batch_size * num_dom_elems, max_neighbors)
        pixel_neighbor_indices = self._get_neighbor_indices(
            dom_elems,
            lambda elem1, elem2: (N.is_pixel_neighbor(elem1, elem2) and
                                  N.is_text_neighbor(elem1, elem2)))

        # SequenceBatch of shape
        # (batch_size * num_dom_elems, max_neighbors, embed_dim)
        pixel_neighbor_embeds = SequenceBatch.embed(pixel_neighbor_indices,
                                                    base_dom_embeds_flat)

        # TODO(kelvin): switch to reduce_max
        # (batch_size * num_dom_elems, embed_dim)
        pixel_neighbor_embeds_flat = SequenceBatch.reduce_mean(
            pixel_neighbor_embeds, allow_empty=True)

        dom_embeds_flat = torch.cat([
            base_dom_embeds_flat, pixel_neighbor_embeds_flat,
            dom_neighbor_embeds
        ], 1)
        dom_embeds = dom_embeds_flat.view(batch_size, num_dom_elems,
                                          self.embed_dim)

        return dom_embeds
Example #12
    def forward(self, dom_elem):
        """Embeds a batch of DOMElements.

        Args:
            dom_elem (list[list[DOMElement]]): a batch of DOM element lists.
                Each list must already be padded to have the same number of
                DOM elements.

        Returns:
            Variable[FloatTensor]: batch x num_dom_elems x embed_dim
        """
        # Check that the batches are rectangular
        for dom_list in dom_elem:
            assert len(dom_list) == len(dom_elem[0])

        num_dom_elems = len(dom_elem[0])
        dom_elem = flatten(dom_elem)

        # (batch * max_dom_num) x lstm_dim
        text_embeddings = []
        for batch in as_batches(dom_elem, 100):
            final_states, combined_states = self._utterance_embedder(
                [word_tokenize(dom.text) for dom in batch])
            text_embeddings.append(final_states)
        text_embeddings = torch.cat(text_embeddings, 0)

        # (batch * max_dom_num) x tag_embed_dim
        tag_embeddings = self._tag_embedder.embed_tokens(
            [dom.tag for dom in dom_elem])

        value_embeddings = self._value_embedder.embed_tokens(
            [bool(dom.value) for dom in dom_elem])

        tampered_embeddings = self._tampered_embedder.embed_tokens(
            [dom.tampered for dom in dom_elem])

        class_embeddings = self._classes_embedder.embed_tokens(
            [dom.classes for dom in dom_elem])

        # (batch * max_dom_num) x 4
        fg_colors = [
            GPUVariable(torch.FloatTensor(elem.fg_color)) for elem in dom_elem
        ]
        fg_colors = torch.stack(fg_colors)
        bg_colors = [
            GPUVariable(torch.FloatTensor(elem.bg_color)) for elem in dom_elem
        ]
        bg_colors = torch.stack(bg_colors)

        # (batch * max_dom_num) x 2
        coords = [
            GPUVariable(
                torch.FloatTensor((float(elem.left) / positions.IMAGE_COLS,
                                   float(elem.top) / positions.IMAGE_ROWS)))
            for elem in dom_elem
        ]
        coords = torch.stack(coords)

        # (batch * max_dom_num) x dom_embed_dim
        dom_embeddings = torch.cat(
            (text_embeddings, tag_embeddings, value_embeddings,
             tampered_embeddings, class_embeddings, coords, fg_colors,
             bg_colors),
            dim=1)

        # batch x max_dom_num x dom_embed_dim
        return dom_embeddings.view(-1, num_dom_elems, self.embed_dim)
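as_batches above only chunks the flattened DOM elements so that the utterance embedder sees at most 100 of them at a time. A minimal sketch of such a helper, assuming it simply yields consecutive slices (the project's actual utility may differ):

def as_batches(items, batch_size):
    """Yield consecutive chunks of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]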
Example #13
    def advance(self, terminated, beams, empirical_distributions):
        """Advance a batch of beams.

        Args:
            terminated (list[set[ParsePath]]): a batch of all the
                terminated paths found so far for each beam.
            beams (list[list[ParsePath]]): a batch of beams.
                All paths on all beams have the same length (all
                should be unterminated)
            empirical_distributions (list[list[float]]): a batch of
                distributions over the corresponding beams.

        Returns:
            list[set[ParsePath]]: a batch of terminated beams
                (in the same order as the input beams)
            list[list[ParsePath]]: a batch of new beams all extended
                by one time step
            list[list[float]]: the new empirical distributions over these
                particles
        """
        # nothing on the beams should be terminated
        # terminated paths should be in the terminated set
        for beam in beams:
            for path in beam:
                assert not path.terminated

        path_extensions = [[path.extend() for path in beam] for beam in beams]

        # for exploration, use a parser which pretends like every utterance
        # is the first utterance it is seeing
        ignore_previous_utterances = (
            self._config.independent_utterance_exploration)

        # Use the ParseModel to score
        self._decoder.parse_model.score(flatten(path_extensions),
                                        ignore_previous_utterances,
                                        self._decoder.caching)

        new_beams = []
        new_distributions = []
        gamma = self._config.exploration_gamma
        for terminated_set, cases, distribution in zip(
                terminated, path_extensions, empirical_distributions):

            new_path_log_probs = []
            paths_to_sample_from = []

            for case, path_prob in zip(cases, distribution):
                for continuation in case.valid_continuations(
                        self._decoder.path_checker):
                    # Add all the terminated paths
                    if continuation.terminated:
                        terminated_set.add(continuation)
                    else:
                        # Sample from unterminated paths
                        new_path_log_probs.append(
                                gamma * continuation[-1].log_prob +
                                np.log(path_prob))
                        paths_to_sample_from.append(continuation)

            if len(paths_to_sample_from) == 0:
                new_beams.append([])
                new_distributions.append([])
                continue

            new_path_probs = softmax(new_path_log_probs)

            new_particles, new_distribution = self._sample(
                    paths_to_sample_from, new_path_probs)
            new_beams.append(new_particles)
            new_distributions.append(new_distribution)