def eval_step(self, example, feed_gold_query=False):
    """Evaluates the model on a specific example.

    Inputs:
        example (utterance example): Example to feed.
        feed_gold_query (bool): Whether or not to token-feed the gold query.
    """
    dy.renew_cg()

    # First, encode the input sequences.
    input_sequences = example.histories(
        self.params.maximum_utterances - 1) + [example.input_sequence()]
    final_state, utterance_hidden_states = self._encode_input_sequences(
        input_sequences)

    # Add positional embeddings if appropriate.
    if self.params.state_positional_embeddings:
        utterance_hidden_states, flat_sequence = self._add_positional_embeddings(
            utterance_hidden_states, input_sequences)

    # Encode the snippets.
    snippets = []
    if self.params.use_snippets:
        snippets = self._encode_snippets(example.previous_query(), snippets)

    # Decode.
    flat_seq = []
    for sequence in input_sequences:
        flat_seq.extend(sequence)

    decoder_results = self.decoder(
        final_state,
        utterance_hidden_states,
        self.params.train_maximum_sql_length,
        snippets=snippets,
        gold_sequence=example.gold_query() if feed_gold_query else None,
        dropout_amount=self.dropout,
        input_sequence=flat_seq,
        controller=self.controller)

    all_scores = [step.scores for step in decoder_results.predictions]
    all_alignments = [
        step.aligned_tokens for step in decoder_results.predictions]

    # Only compute a loss when the gold query was token-fed, since the
    # per-step scores then align with the gold tokens.
    loss = dy.zeros((1, 1))
    if feed_gold_query:
        loss = du.compute_loss(example.gold_query(), all_scores,
                               all_alignments, get_token_indices)
    predicted_seq = du.get_seq_from_scores(all_scores, all_alignments)

    return decoder_results, loss, predicted_seq
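# A minimal evaluation-loop sketch showing how eval_step above might be
# driven. `evaluate_sketch` and `dev_examples` are hypothetical names, not
# part of this codebase; only model.eval_step and du.per_token_accuracy are
# taken from the surrounding code.
def evaluate_sketch(model, dev_examples):
    accuracies = []
    for example in dev_examples:
        # Token-feed the gold query so the returned loss is meaningful and
        # the predicted sequence is scored under teacher forcing.
        _, loss, predicted_seq = model.eval_step(example,
                                                 feed_gold_query=True)
        accuracies.append(
            du.per_token_accuracy(example.gold_query(), predicted_seq))
    return sum(accuracies) / len(accuracies)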
def predict_turn(self,
                 utterance_final_state,
                 input_hidden_states,
                 max_generation_length,
                 gold_query=None,
                 snippets=None,
                 input_sequence=None,
                 feed_gold_tokens=False,
                 training=False,
                 first_utterance=True,
                 gold_copy=None):
    """Gets a prediction for a single turn -- calls the decoder and updates
    the loss, etc.

    TODO: this can probably be split into two methods, one that just
    predicts and another that computes the loss.
    """
    predicted_sequence = []
    fed_sequence = []
    loss = None
    token_accuracy = 0.

    if feed_gold_tokens:
        decoder_results, pick_loss = self.decoder(
            utterance_final_state,
            input_hidden_states,
            max_generation_length,
            gold_sequence=gold_query,
            input_sequence=input_sequence,
            snippets=snippets,
            dropout_amount=self.dropout,
            controller=self.controller,
            first_utterance=first_utterance,
            gold_copy=gold_copy)

        all_scores = [step.scores for step in decoder_results.predictions]
        all_alignments = [
            step.aligned_tokens for step in decoder_results.predictions]

        # Compute the loss against the gold query, or against the gold copy
        # sequence when a pick loss is present.
        if not pick_loss:
            loss = du.compute_loss(gold_query, all_scores, all_alignments,
                                   get_token_indices)
        else:
            loss = du.compute_loss(gold_copy[1:], all_scores, all_alignments,
                                   get_token_indices)
        if pick_loss:
            loss += pick_loss
        if loss is None:
            loss = dy.zeros((1, 1))

        if not training:
            predicted_sequence = du.get_seq_from_scores(
                all_scores, all_alignments)
            token_accuracy = du.per_token_accuracy(gold_query,
                                                   predicted_sequence)

        fed_sequence = gold_query
    else:
        decoder_results, pick_loss = self.decoder(
            utterance_final_state,
            input_hidden_states,
            max_generation_length,
            input_sequence=input_sequence,
            snippets=snippets,
            dropout_amount=self.dropout,
            first_utterance=first_utterance)
        predicted_sequence = decoder_results.sequence
        fed_sequence = predicted_sequence

    # fed_sequence contains EOS, which we don't need when encoding snippets;
    # also ignore the first state, as it contains the BEG encoding.
    decoder_states = [
        pred.decoder_state for pred in decoder_results.predictions]
    if pick_loss:
        fed_sequence = fed_sequence[1:]
    for token, state in zip(fed_sequence[:-1], decoder_states[1:]):
        if snippet_handler.is_snippet(token):
            # Repeat the state once per token in the expanded snippet so the
            # states stay aligned with the expanded token sequence.
            snippet_length = 0
            for snippet in snippets:
                if snippet.name == token:
                    snippet_length = len(snippet.sequence)
                    break
            assert snippet_length > 0
            decoder_states.extend([state for _ in range(snippet_length)])
        else:
            decoder_states.append(state)

    return (predicted_sequence,
            loss,
            token_accuracy,
            decoder_states,
            decoder_results)
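# A hedged sketch of what a token-level loss like du.compute_loss could look
# like in DyNet. The real implementation lives in du and may differ; this
# assumes scores[t] is a DyNet Expression of logits at step t and that
# get_token_indices(token, aligned_tokens) returns the logit index of the
# gold token at that step (an assumption -- the plural name suggests it may
# return several indices in the real code).
def compute_loss_sketch(gold_seq, scores, alignments, get_token_indices,
                        noise=1e-10):
    step_losses = []
    for gold_token, step_scores, aligned in zip(gold_seq, scores, alignments):
        probs = dy.softmax(step_scores)
        index = get_token_indices(gold_token, aligned)
        # Negative log-probability of the gold token, smoothed by `noise`
        # so a zero probability cannot produce -inf.
        step_losses.append(-dy.log(dy.pick(probs, index) + noise))
    # Sum the per-step losses over the whole sequence.
    return dy.esum(step_losses)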
def train_step(self, batch):
    """Training step for a batch of examples.

    Input:
        batch (list of examples): Batch of examples used to update.
    """
    dy.renew_cg(autobatching=True)

    losses = []
    total_gold_tokens = 0

    batch.start()
    while not batch.done():
        # Each batch holds 16 interactions; each interaction is a set of a
        # variable number of utterances.
        example = batch.next()

        # First, encode the input sequences.
        input_sequences = example.histories(
            self.params.maximum_utterances - 1) + [example.input_sequence()]
        final_state, utterance_hidden_states = self._encode_input_sequences(
            input_sequences)

        # Add positional embeddings if appropriate.
        if self.params.state_positional_embeddings:
            utterance_hidden_states, flat_sequence = self._add_positional_embeddings(
                utterance_hidden_states, input_sequences)

        # Encode the snippets.
        snippets = []
        if self.params.use_snippets:
            snippets = self._encode_snippets(example.previous_query(),
                                             snippets)

        # Decode.
        flat_seq = []
        for sequence in input_sequences:
            flat_seq.extend(sequence)

        decoder_results = self.decoder(
            final_state,
            utterance_hidden_states,
            self.params.train_maximum_sql_length,
            snippets=snippets,
            gold_sequence=example.gold_query(),
            dropout_amount=self.dropout,
            input_sequence=flat_seq,
            controller=self.controller)

        all_scores = [step.scores for step in decoder_results.predictions]
        all_alignments = [
            step.aligned_tokens for step in decoder_results.predictions]
        loss = du.compute_loss(example.gold_query(), all_scores,
                               all_alignments, get_token_indices)

        losses.append(loss)
        total_gold_tokens += len(example.gold_query())

    # Normalize by the total number of gold tokens, then run one update.
    average_loss = dy.esum(losses) / total_gold_tokens
    average_loss.forward()
    average_loss.backward()
    self.trainer.update()
    loss_scalar = average_loss.value()

    return loss_scalar
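# A sketch of the outer loop that would drive train_step. `train_batches`,
# `num_epochs`, and the logging are illustrative assumptions rather than
# part of this file; only model.train_step comes from the code above.
def train_epochs_sketch(model, train_batches, num_epochs):
    for epoch in range(num_epochs):
        epoch_loss = 0.
        for batch in train_batches:
            # train_step builds a fresh computation graph per batch and
            # returns the per-gold-token average loss as a Python float.
            epoch_loss += model.train_step(batch)
        print("epoch %d: mean batch loss %.4f"
              % (epoch, epoch_loss / len(train_batches)))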
def predict(model,
            utterances,
            prev_query=None,
            snippets=None,
            gold_seq=None,
            dropout_amount=0.,
            loss_only=False,
            beam_size=1):
    """Predicts a SQL query given an utterance and other various inputs.

    Inputs:
        model (Seq2SeqModel): The model to use to predict.
        utterances (list of list of str): The utterances to predict for.
        prev_query (list of str, optional): The previously generated query.
        snippets (list of Snippet, optional): The snippets available for
            prediction.
        gold_seq (list of str, optional): The gold sequence.
        dropout_amount (float, optional): How much dropout to apply during
            prediction.
        loss_only (bool, optional): Whether to only return the loss.
        beam_size (int, optional): How many items to include in the beam
            during prediction.
    """
    # Guard against the None defaults so the length checks below are safe.
    prev_query = prev_query or []
    snippets = snippets or []
    gold_seq = gold_seq or []

    assert len(prev_query) == 0 or model.use_snippets
    assert len(snippets) == 0 or model.use_snippets
    assert not loss_only or len(gold_seq) > 0

    (enc_state, enc_outputs), input_seq = model.encode_input_sequences(
        utterances, dropout_amount)

    embedded_snippets = []
    if snippets:
        embedded_snippets = model.encode_snippets(
            prev_query, snippets, dropout_amount=dropout_amount)
        assert len(embedded_snippets) == len(snippets)

    scalar_loss = 0
    if gold_seq:
        # Token-feed the gold sequence and score it.
        item = model.decode(
            enc_state,
            enc_outputs,
            input_seq,
            snippets=embedded_snippets if model.use_snippets else [],
            gold_seq=gold_seq,
            dropout_amount=dropout_amount)[0]
        scores = item.scores
        scores_by_timestep = [score[0] for score in scores]
        score_maps_by_timestep = [score[1] for score in scores]

        assert scores_by_timestep[0].dim()[0][0] == len(
            score_maps_by_timestep[0])
        assert len(score_maps_by_timestep[0]) >= len(
            model.output_vocab) + len(snippets)

        loss = du.compute_loss(gold_seq,
                               scores_by_timestep,
                               score_maps_by_timestep,
                               gold_tok_to_id,
                               noise=1e-11)
        if loss_only:
            return loss
        # Convert the loss Expression to a float for the return value.
        scalar_loss = loss.value()
        sequence = du.get_seq_from_scores(scores_by_timestep,
                                          score_maps_by_timestep)
    else:
        # Beam-search decode without the gold sequence.
        item = model.decode(
            enc_state,
            enc_outputs,
            input_seq,
            snippets=embedded_snippets if model.use_snippets else [],
            beam_size=beam_size)[0]
        sequence = item.sequence

    token_acc = 0
    if gold_seq:
        token_acc = du.per_token_accuracy(gold_seq, sequence)

    return sequence, scalar_loss, token_acc, item.probability
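# Example call patterns for predict; `model`, `utterances`, `prev`, `snips`,
# and `gold` are stand-ins for objects built elsewhere in the pipeline.
#
# Scoring a known gold query (teacher forcing), returning only the loss:
#   loss = predict(model, utterances, prev_query=prev, snippets=snips,
#                  gold_seq=gold, loss_only=True)
#
# Open-ended inference with a beam of 5:
#   sequence, _, _, probability = predict(model, utterances,
#                                         prev_query=prev, snippets=snips,
#                                         beam_size=5)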