def train_step(self, examples):
    # sample a beam of logical forms for each example
    beams = self.predictions(examples, train=True)

    all_cases = []  # a list of ParseCases to give to ParseModel
    all_case_weights = []  # the weights associated with the cases
    for example, paths in zip(examples, beams):
        case_weights = self._case_weighter(paths, example)
        case_weights = flatten(case_weights)
        cases = flatten(paths)
        assert len(case_weights) == sum(len(p) for p in paths)
        all_cases.extend(cases)
        all_case_weights.extend(case_weights)

    # for efficiency, prune cases with weight 0
    cases_to_reinforce = []
    weights_to_reinforce = []
    for case, weight in zip(all_cases, all_case_weights):
        if weight != 0:
            cases_to_reinforce.append(case)
            weights_to_reinforce.append(weight)

    # update value function
    vf_examples = []
    for example, paths in zip(examples, beams):
        vf_examples.extend(
            ValueFunctionExample.examples_from_paths(paths, example))
    self._value_function.train_step(vf_examples)

    # update parse model
    self._parse_model.train_step(
        cases_to_reinforce, weights_to_reinforce, caching=False)
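# NOTE: `flatten` is used throughout these methods to collapse one level
# of nesting. A minimal sketch of the assumed helper (hypothetical; the
# real utility is presumably imported from a shared module):
def flatten(nested):
    """Collapse one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return [item for sublist in nested for item in sublist]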
def _query_embeds(self, states, query_entries):
    """Given a batch of states, embed the keys and values of each state's query.

    Args:
        states (list[MiniWoBState])
        query_entries: unused

    Returns:
        entry_embeds (SequenceBatch): batch x num_keys x (2 * embed_dim),
            the key and value embeddings concatenated
    """
    fields_batch = [state.fields for state in states]
    # list[list[list[unicode]]] (batch x num_keys x key length)
    values_batch = [[word_tokenize(value) for value in fields.values]
                    for fields in fields_batch]
    keys_batch = [[word_tokenize(key) for key in fields.keys]
                  for fields in fields_batch]

    # pad to the max number of fields in the batch
    batch_size = len(fields_batch)
    max_num_fields = max(len(values) for values in values_batch)
    max_num_fields = max(max_num_fields, 1)  # ensure non-empty
    mask = torch.ones(batch_size, max_num_fields)
    assert len(keys_batch) == len(values_batch) == len(mask)
    for keys, values, submask in zip(keys_batch, values_batch, mask):
        assert len(keys) == len(values)
        if len(keys) < max_num_fields:
            submask[len(keys):] = 0.
            keys.extend([[UtteranceVocab.PAD]
                         for _ in xrange(max_num_fields - len(keys))])
            values.extend([[UtteranceVocab.PAD]
                           for _ in xrange(max_num_fields - len(values))])

    # flatten to list[list[unicode]] ((batch * num_keys) x key length)
    keys_batch = flatten(keys_batch)
    values_batch = flatten(values_batch)

    # embed and reshape to batch x num_keys x embed_dim
    key_embeds, _ = self._utterance_embedder(keys_batch)
    key_embeds = key_embeds.view(
        batch_size, max_num_fields, self._utterance_embedder.embed_dim)
    value_embeds, _ = self._utterance_embedder(values_batch)
    value_embeds = value_embeds.view(
        batch_size, max_num_fields, self._utterance_embedder.embed_dim)

    # mask out the padded entries
    key_embeds = SequenceBatch(key_embeds, GPUVariable(mask))
    value_embeds = SequenceBatch(value_embeds, GPUVariable(mask))

    entry_embed_values = torch.cat(
        [key_embeds.values, value_embeds.values], 2)
    entry_embeds = SequenceBatch(entry_embed_values, key_embeds.mask)
    return entry_embeds
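# Worked example of the padding above (illustrative field names): with
#   fields_batch = [{'from': 'NYC', 'to': 'LA'}, {'name': 'Bob'}]
# max_num_fields == 2, so the second example's keys and values are padded
# with [UtteranceVocab.PAD] and its mask row becomes [1., 0.].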
def __init__(self, hard_copy_vocab, source_tokens, token_lengths):
    """HardCopyDynamicVocab.

    NOTE: HardCopyDynamicVocab is blind to casing.

    Args:
        hard_copy_vocab (HardCopyVocab)
        source_tokens (list[list[unicode]]): the outer list is over
            input channels; each inner list is a tokenized sentence
        token_lengths (list[int]): the number of copy tokens allotted
            to each channel
    """
    num_channels = len(source_tokens)

    # partition the shared pool of copy tokens across channels
    tok_cum = [0] + np.cumsum(token_lengths).tolist()
    tokens_by_channel = [
        hard_copy_vocab.copy_tokens[tok_cum[i]:tok_cum[i + 1]]
        for i in range(num_channels)]
    self.tokens_by_channel = tokens_by_channel

    flatten = lambda l: [item for sublist in l for item in sublist]

    # pair each copy word with a copy token; note that zip only zips up
    # to the shorter of the two sequences (desired behavior)
    copy_words = [sentence_tokens + [hard_copy_vocab.STOP]
                  for sentence_tokens in source_tokens]
    copy_pairs = flatten([zip(copy_words[i], tokens_by_channel[i])
                          for i in range(num_channels)])

    # iterate in reverse so that for a duplicated word, the earliest
    # pair (in original order) wins
    self.word_to_copy_token = {w: t for w, t in copy_pairs[::-1]}
    self.copy_token_to_word = {t: w for w, t in copy_pairs[::-1]}

    self.base_vocab = hard_copy_vocab
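# A small worked example of the pairing above, assuming a hypothetical
# vocab whose copy tokens are <copy0>, <copy1>, ...:
#   source_tokens = [[u'turn', u'left'], [u'go']]
#   token_lengths = [3, 2]
#   tokens_by_channel = [[<copy0>, <copy1>, <copy2>], [<copy3>, <copy4>]]
#   copy_pairs = [(u'turn', <copy0>), (u'left', <copy1>), (STOP, <copy2>),
#                 (u'go', <copy3>), (STOP, <copy4>)]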
def __init__(self, whitelists, vocab, word_to_forms, always_allowed=None):
    """Modify extension probs such that only words appearing in the
    whitelist (or having a variant in the whitelist) are allowed.

    Args:
        whitelists (list[list[unicode]]): a batch of whitelists, one per
            example. Each whitelist is a list of permitted words.
        vocab (HardCopyVocab)
        word_to_forms (Callable[[unicode], list[unicode]]): a function
            mapping a word to its forms
        always_allowed (list[unicode]): words that are allowed in every
            example
    """
    # importantly, vocab should NOT be a HardCopyDynamicVocab
    assert isinstance(vocab, HardCopyVocab)

    if always_allowed is None:
        # always allow STOP, copy tokens, and various stop words
        always_allowed = []
        always_allowed.append(vocab.STOP)
        always_allowed.extend(vocab.copy_tokens)
        always_allowed.extend(STOPWORDS)  # all lower case

    # so far, these should all be unique
    always_allowed_indices = [vocab.word2index(w)
                              for w in always_allowed if w in vocab]

    # precompute actual allowed words
    all_whitelist_words = set(flatten(whitelists))
    word_to_form_indices = self._word_to_form_indices(
        all_whitelist_words, vocab, word_to_forms)

    allowed_indices = []
    for whitelist in whitelists:
        indices = self._whitelist_to_indices(
            whitelist, word_to_form_indices, always_allowed_indices)
        allowed_indices.append(indices)

    # returned indices may contain some duplicates
    self.allowed_indices = allowed_indices
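# A hypothetical `word_to_forms` callable that could be passed in; the
# real function may instead use a morphological analyzer:
def naive_word_to_forms(word):
    """Return simple case/inflection variants of `word` (illustrative)."""
    base = word.lower()
    return [base, base.capitalize(), base + u's']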
def _get_neighbor_indices(self, dom_elements, is_neighbor):
    """Compute neighbor indices.

    Args:
        dom_elements (list[list[DOMElement]]): a batch of DOM element
            sequences; may include PAD elements
        is_neighbor (Callable[[DOMElement, DOMElement], bool]): True if
            two DOM elements are neighbors of each other, otherwise False

    Returns:
        SequenceBatch: of shape (total_dom_elems, max_neighbors)
    """
    dom_element_ids = [id(e) for e in flatten(dom_elements)]
    vocab = SuperSimpleVocab(dom_element_ids)

    neighbors_batch = []
    for dom_batch in dom_elements:
        for dom_elem in dom_batch:
            neighbors = []
            # optimization: a DOM PAD never has neighbors
            if not isinstance(dom_elem, DOMElementPAD):
                for neighbor in dom_batch:
                    if is_neighbor(dom_elem, neighbor):
                        neighbors.append(id(neighbor))
            neighbors_batch.append(neighbors)

    neighbor_indices = SequenceBatch.from_sequences(
        neighbors_batch, vocab, min_seq_length=1)
    return neighbor_indices
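# Example of a hypothetical `is_neighbor` callable, for illustration:
#   same_tag = lambda e1, e2: e1 is not e2 and e1.tag == e2.tag
#   neighbor_indices = self._get_neighbor_indices(dom_elements, same_tag)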
def decoder_inputs_and_outputs(target_words, base_vocab):
    """Convert a sequence of tokens into a decoder input seq and output seq.

    Args:
        target_words (list[unicode])
        base_vocab (Vocab)

    Returns:
        input_words (list[unicode])
        output_words (list[unicode])
    """
    # prepend <start> token
    input_words = [base_vocab.START] + flatten(target_words)
    # append <stop> token
    output_words = flatten(target_words) + [base_vocab.STOP]
    return input_words, output_words
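# Example, assuming a vocab with START == u'<start>' and STOP == u'<stop>'
# and target_words given as one nested channel:
#   decoder_inputs_and_outputs([[u'big', u'cat']], vocab)
#   => ([u'<start>', u'big', u'cat'], [u'big', u'cat', u'<stop>'])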
def _edit_batch(self, examples, max_seq_length, beam_size, constrain_vocab):
    # should only run in evaluation mode
    assert not self.training

    input_words, output_words = self._batch_editor_examples(examples)
    base_vocab = self.base_vocab
    dynamic_vocabs = self._compute_dynamic_vocabs(input_words, base_vocab)
    dynamic_token_embedder = DynamicMultiVocabTokenEmbedder(
        self.base_source_token_embedder, dynamic_vocabs, base_vocab)

    encoder_input = self.encoder.preprocess(
        input_words, output_words, dynamic_token_embedder, volatile=True)
    encoder_output, _ = self.encoder(encoder_input)

    extension_probs_modifiers = []
    if constrain_vocab:
        # will contain duplicates; that's ok
        whitelists = [flatten(ex.input_words) for ex in examples]
        vocab_constrainer = LexicalWhitelister(
            whitelists, self.base_vocab, word_to_forms)
        extension_probs_modifiers.append(vocab_constrainer)

    beams, decoder_traces = self.test_decoder_beam.decode(
        examples, encoder_output, beam_size=beam_size,
        max_seq_length=max_seq_length,
        extension_probs_modifiers=extension_probs_modifiers)

    # replace copy tokens in predictions with actual words,
    # modifying beams in-place
    for beam, dyna_vocab in izip(beams, dynamic_vocabs):
        copy_to_word = dyna_vocab.copy_token_to_word
        for i, seq in enumerate(beam):
            beam[i] = [copy_to_word.get(w, w) for w in seq]

    edit_traces = [EditTrace(ex, d_trace.beam_traces[-1], dyna_vocab)
                   for ex, d_trace, dyna_vocab in izip(
                       examples, decoder_traces, dynamic_vocabs)]
    return beams, edit_traces
def update_from_demonstrations(self, demonstrations, take_grad_step):
    """Takes a gradient step on the cross-entropy loss computed from a
    batch of demonstrations.

    Args:
        demonstrations (list[EpisodeGraph])
        take_grad_step (Callable): takes a loss Variable and performs a
            gradient step on it
    """
    experiences = flatten(
        episode_graph.to_experiences() for episode_graph in demonstrations)
    scored_experiences = self._score_experiences(experiences)
    loss = -sum(exp.log_prob
                for exp in scored_experiences) / len(demonstrations)
    self._clear_cache()
    take_grad_step(loss)
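# A minimal sketch of a `take_grad_step` callable (an assumption; the real
# training loop may also clip gradients or log statistics):
def make_take_grad_step(optimizer):
    def take_grad_step(loss):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return take_grad_step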
def _score_episodes(self, episodes):
    """Score all the experiences in each of the episodes.

    Args:
        episodes (list[Episode])

    Returns:
        scored_episodes (list[Episode]): episodes whose experiences have
            been scored
    """
    # score all experiences in one flat batch
    experiences = flatten(episodes)
    scored_experiences = self._score_experiences(experiences)

    # regroup the scored experiences back into episodes
    scored_episodes = []
    scored_experiences_reversed = list(reversed(scored_experiences))
    for ep in episodes:
        scored_ep = Episode()
        for _ in range(len(ep)):
            scored_ep.append(scored_experiences_reversed.pop())
        scored_episodes.append(scored_ep)
    return scored_episodes
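# The reverse-then-pop pattern above restores the original order: with
# episode lengths [2, 1] and scored experiences [s0, s1, s2], popping from
# the reversed list yields s0, s1, then s2, regrouped as [[s0, s1], [s2]].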
def beam_duplicate(self, beam_size):
    # repeat each dynamic vocab beam_size times, so every beam candidate
    # of an example shares that example's vocab
    duplicated_dynamic_vocabs = flatten(
        [dyna_vocab] * beam_size for dyna_vocab in self.dynamic_vocabs)
    return DynamicMultiVocabTokenEmbedder(
        self.base_embedder, duplicated_dynamic_vocabs, self.base_vocab)
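# e.g. with dynamic_vocabs == [v1, v2] and beam_size == 3, the duplicated
# list is [v1, v1, v1, v2, v2, v2] -- one vocab per candidate, presumably
# matching an example-major layout of the duplicated beams.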
def forward(self, dom_elems, base_dom_embeds):
    """Embeds a batch of DOMElement sequences by mixing the base
    embeddings of each element's neighbors.

    Args:
        dom_elems (list[list[DOMElement]]): a batch of DOMElement
            sequences to embed. All sequences must be padded to have the
            same number of DOM elements.
        base_dom_embeds (Variable[FloatTensor]): of shape
            (batch_size, num_dom_elems, base_dom_embed_dim)

    Returns:
        dom_embeds (Variable[FloatTensor]): of shape
            (batch_size, num_dom_elems, embed_dim)
    """
    batch_size, num_dom_elems, embed_dim = base_dom_embeds.size()

    # flatten, for easier processing
    base_dom_embeds_flat = base_dom_embeds.view(
        batch_size * num_dom_elems, embed_dim)

    # list of length batch_size * num_dom_elems
    dom_elems_flat = flatten(dom_elems)
    assert len(dom_elems_flat) == batch_size * num_dom_elems

    # DOM neighbors whose LCA depth runs from self._lca_depth_start to
    # self._lca_depth_end, e.g. depth 3 to 6 (root is depth 1)
    dom_neighbor_embeds = []
    for k in xrange(self._lca_depth_start, self._lca_depth_end):
        # bind k by value (k=k) to avoid late-binding surprises
        is_neighbor_fn = lambda elem1, elem2, k=k: (
            N.is_depth_k_lca_neighbor(elem1, elem2, k, self._lca_cache)
            and N.is_text_neighbor(elem1, elem2))
        dom_neighbor_indices = self._get_neighbor_indices(
            dom_elems, is_neighbor_fn)

        # TODO: reduce_max
        # (batch_size * num_dom_elems, embed_dim)
        neighbor_embedding = SequenceBatch.reduce_sum(
            SequenceBatch.embed(dom_neighbor_indices, base_dom_embeds_flat))
        # TODO: batch these projections for performance
        projected_neighbor_embedding = self._dom_neighbor_projection(
            neighbor_embedding)
        dom_neighbor_embeds.append(projected_neighbor_embedding)

    # (batch_size * num_dom_elems,
    #  lca_range * (base_embed_dim / lca_range))
    dom_neighbor_embeds = torch.cat(dom_neighbor_embeds, 1)

    # SequenceBatch of shape (batch_size * num_dom_elems, max_neighbors)
    pixel_neighbor_indices = self._get_neighbor_indices(
        dom_elems,
        lambda elem1, elem2: (N.is_pixel_neighbor(elem1, elem2)
                              and N.is_text_neighbor(elem1, elem2)))
    # SequenceBatch of shape
    # (batch_size * num_dom_elems, max_neighbors, embed_dim)
    pixel_neighbor_embeds = SequenceBatch.embed(
        pixel_neighbor_indices, base_dom_embeds_flat)
    # TODO(kelvin): switch to reduce_max
    # (batch_size * num_dom_elems, embed_dim)
    pixel_neighbor_embeds_flat = SequenceBatch.reduce_mean(
        pixel_neighbor_embeds, allow_empty=True)

    dom_embeds_flat = torch.cat([
        base_dom_embeds_flat, pixel_neighbor_embeds_flat,
        dom_neighbor_embeds
    ], 1)
    dom_embeds = dom_embeds_flat.view(
        batch_size, num_dom_elems, self.embed_dim)
    return dom_embeds
def forward(self, dom_elem):
    """Embeds a batch of DOMElements.

    Args:
        dom_elem (list[list[DOMElement]]): a batch of DOM element lists.
            Each list must already be padded to have the same number of
            DOM elements.

    Returns:
        Variable[FloatTensor]: batch x num_dom_elems x embed_dim
    """
    # check that the batch is rectangular
    for dom_list in dom_elem:
        assert len(dom_list) == len(dom_elem[0])

    num_dom_elems = len(dom_elem[0])
    dom_elem = flatten(dom_elem)

    # (batch * max_dom_num) x lstm_dim
    text_embeddings = []
    for batch in as_batches(dom_elem, 100):
        final_states, combined_states = self._utterance_embedder(
            [word_tokenize(dom.text) for dom in batch])
        text_embeddings.append(final_states)
    text_embeddings = torch.cat(text_embeddings, 0)

    # (batch * max_dom_num) x tag_embed_dim
    tag_embeddings = self._tag_embedder.embed_tokens(
        [dom.tag for dom in dom_elem])
    value_embeddings = self._value_embedder.embed_tokens(
        [bool(dom.value) for dom in dom_elem])
    tampered_embeddings = self._tampered_embedder.embed_tokens(
        [dom.tampered for dom in dom_elem])
    class_embeddings = self._classes_embedder.embed_tokens(
        [dom.classes for dom in dom_elem])

    # (batch * max_dom_num) x 4
    fg_colors = torch.stack([
        GPUVariable(torch.FloatTensor(elem.fg_color)) for elem in dom_elem])
    bg_colors = torch.stack([
        GPUVariable(torch.FloatTensor(elem.bg_color)) for elem in dom_elem])

    # (batch * max_dom_num) x 2
    coords = torch.stack([
        GPUVariable(torch.FloatTensor(
            (float(elem.left) / positions.IMAGE_COLS,
             float(elem.top) / positions.IMAGE_ROWS)))
        for elem in dom_elem])

    # (batch * max_dom_num) x dom_embed_dim
    dom_embeddings = torch.cat(
        (text_embeddings, tag_embeddings, value_embeddings,
         tampered_embeddings, class_embeddings, coords, fg_colors,
         bg_colors), dim=1)

    # batch x max_dom_num x dom_embed_dim
    return dom_embeddings.view(-1, num_dom_elems, self.embed_dim)
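# A minimal sketch of the assumed `as_batches` helper, which yields chunks
# of at most `batch_size` items (hypothetical; the real utility is
# presumably imported from a shared module):
def as_batches(items, batch_size):
    for i in xrange(0, len(items), batch_size):
        yield items[i:i + batch_size]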
def advance(self, terminated, beams, empirical_distributions):
    """Advance a batch of beams.

    Args:
        terminated (list[set[ParsePath]]): a batch of all the terminated
            paths found so far for each beam
        beams (list[list[ParsePath]]): a batch of beams. All paths on all
            beams have the same length (all should be unterminated)
        empirical_distributions (list[list[float]]): a batch of
            distributions over the corresponding beams

    Returns:
        list[set[ParsePath]]: a batch of terminated beams (in the same
            order as the input beams)
        list[list[ParsePath]]: a batch of new beams, all extended by one
            time step
        list[list[float]]: the new empirical distributions over these
            particles
    """
    # nothing on the beams should be terminated;
    # terminated paths should be in the terminated set
    for beam in beams:
        for path in beam:
            assert not path.terminated

    path_extensions = [[path.extend() for path in beam] for beam in beams]

    # for exploration, use a parser which pretends that every utterance
    # is the first utterance it has seen
    ignore_previous_utterances = \
        self._config.independent_utterance_exploration

    # use the ParseModel to score the extensions
    self._decoder.parse_model.score(
        flatten(path_extensions), ignore_previous_utterances,
        self._decoder.caching)

    new_beams = []
    new_distributions = []
    gamma = self._config.exploration_gamma
    for terminated_set, cases, distribution in zip(
            terminated, path_extensions, empirical_distributions):
        new_path_log_probs = []
        paths_to_sample_from = []
        for case, path_prob in zip(cases, distribution):
            for continuation in case.valid_continuations(
                    self._decoder.path_checker):
                if continuation.terminated:
                    # collect all the terminated paths
                    terminated_set.add(continuation)
                else:
                    # sample only among unterminated paths
                    new_path_log_probs.append(
                        gamma * continuation[-1].log_prob
                        + np.log(path_prob))
                    paths_to_sample_from.append(continuation)

        if len(paths_to_sample_from) == 0:
            new_beams.append([])
            new_distributions.append([])
            continue

        new_path_probs = softmax(new_path_log_probs)
        new_particles, new_distribution = self._sample(
            paths_to_sample_from, new_path_probs)
        new_beams.append(new_particles)
        new_distributions.append(new_distribution)

    # return values match the docstring above (the original snippet was
    # missing this return)
    return terminated, new_beams, new_distributions
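# A minimal sketch of the `softmax` used above, applied to a list of log
# probabilities with max-subtraction for numerical stability (an
# assumption; the real helper may differ):
import numpy as np

def softmax(log_probs):
    """Exponentiate and normalize log probabilities."""
    log_probs = np.asarray(log_probs, dtype=np.float64)
    exp = np.exp(log_probs - log_probs.max())
    return exp / exp.sum()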