def predict_turn(self,
                 utterance_final_state,
                 input_hidden_states,
                 max_generation_length,
                 gold_query=None,
                 snippets=None,
                 input_sequence=None,
                 feed_gold_tokens=False,
                 training=False,
                 first_utterance=True,
                 gold_copy=None):
    """ Runs the decoder for a single turn and gathers its outputs.

    When `feed_gold_tokens` is set, the decoder is called with
    `gold_sequence=gold_query` and a loss is computed from the per-step
    scores and alignments; otherwise the decoder is called without a gold
    sequence, its `sequence` attribute is used as the prediction, and the
    returned loss stays None.

    TODO: this can probably be split into two methods, one that just
        predicts and another that computes the loss.

    Returns:
        tuple of (predicted_sequence, loss, token_accuracy, decoder_states,
        decoder_results).
    """
    loss = None
    token_accuracy = 0.
    predicted_sequence = []

    if not feed_gold_tokens:
        decoder_results, pick_loss = self.decoder(
            utterance_final_state,
            input_hidden_states,
            max_generation_length,
            input_sequence=input_sequence,
            snippets=snippets,
            dropout_amount=self.dropout,
            first_utterance=first_utterance)
        predicted_sequence = decoder_results.sequence
        fed_sequence = predicted_sequence
    else:
        decoder_results, pick_loss = self.decoder(
            utterance_final_state,
            input_hidden_states,
            max_generation_length,
            gold_sequence=gold_query,
            input_sequence=input_sequence,
            snippets=snippets,
            dropout_amount=self.dropout,
            controller=self.controller,
            first_utterance=first_utterance,
            gold_copy=gold_copy)

        all_scores = []
        all_alignments = []
        for step in decoder_results.predictions:
            all_scores.append(step.scores)
            all_alignments.append(step.aligned_tokens)

        # When the decoder reports a pick loss, the token loss is computed
        # against gold_copy (minus its first element) and the pick loss is
        # added on top; otherwise it is computed against gold_query.
        if pick_loss:
            loss = du.compute_loss(gold_copy[1:],
                                   all_scores,
                                   all_alignments,
                                   get_token_indices)
            loss += pick_loss
        else:
            loss = du.compute_loss(gold_query,
                                   all_scores,
                                   all_alignments,
                                   get_token_indices)
        if not loss:
            loss = dy.zeros((1, 1))

        if not training:
            predicted_sequence = du.get_seq_from_scores(all_scores,
                                                        all_alignments)
            token_accuracy = du.per_token_accuracy(gold_query,
                                                   predicted_sequence)
        fed_sequence = gold_query

    decoder_states = [
        prediction.decoder_state
        for prediction in decoder_results.predictions
    ]
    if pick_loss:
        fed_sequence = fed_sequence[1:]

    # Per the original note: fed_sequence contains EOS, which is not needed
    # when encoding snippets, and the first state is skipped because it
    # contains the BEG encoding. Both slices are taken before the loop, so
    # growing decoder_states inside it does not affect the iteration.
    token_state_pairs = zip(fed_sequence[:-1], decoder_states[1:])
    for token, state in token_state_pairs:
        if not snippet_handler.is_snippet(token):
            decoder_states.append(state)
            continue

        # Repeat the state once per token of the first snippet whose name
        # matches the fed token.
        snippet_length = next(
            (len(candidate.sequence)
             for candidate in snippets
             if candidate.name == token),
            0)
        assert snippet_length > 0
        decoder_states.extend([state] * snippet_length)

    return (predicted_sequence,
            loss,
            token_accuracy,
            decoder_states,
            decoder_results)
def evaluate_utterance_sample(sample,
                              model,
                              max_generation_length,
                              name="",
                              gold_forcing=False,
                              metrics=None,
                              total_num=-1,
                              database_username="",
                              database_password="",
                              database_timeout=0,
                              write_results=False):
    """Evaluates every item in a sample and averages the requested metrics.

    NOTE(review): another definition named `evaluate_utterance_sample`
    appears later in this file; if both live at module level the later one
    shadows this one -- confirm which is intended.

    Inputs:
        sample: Sized iterable of items to evaluate. Each item must provide
            gold_query(), flatten_sequence(), input_sequence(),
            original_gold_queries(), gold_tables(), utterance_index and
            interaction.identifier.
        model: Object providing eval_step(item, max_generation_length,
            feed_gold_query=...), which returns a 3-tuple whose last two
            elements are the loss and the predicted sequence.
        max_generation_length: Passed through to model.eval_step.
        name (str): Prefix of the "<name>_predictions.json" output file; also
            passed to get_progressbar.
        gold_forcing (bool): Passed to eval_step as feed_gold_query and on to
            update_sums.
        metrics: Non-empty collection of metric keys to accumulate.
        total_num (int): Divisor handed to construct_averages; a negative
            value means len(sample).
        database_username (str): Forwarded to the prediction writer and to
            update_sums.
        database_password (str): Forwarded likewise.
        database_timeout: Forwarded likewise.
        write_results (bool): Whether to write each prediction to the file.

    Returns:
        tuple: (construct_averages(metrics_sums, total_num), None).

    Side effects:
        Creates (or truncates) "<name>_predictions.json" even when
        write_results is False, and prints the file name.
    """
    assert metrics

    if total_num < 0:
        total_num = len(sample)

    metrics_sums = {metric: 0. for metric in metrics}

    # Manage the file with a context manager so the handle is released even
    # if eval_step, a metric update, or the writer raises mid-loop.
    with open(name + "_predictions.json", "w") as predictions_file:
        print("Predicting with filename " + str(name) + "_predictions.json")
        progbar = get_progressbar(name, len(sample))
        progbar.start()

        for i, item in enumerate(sample):
            _, loss, predicted_seq = model.eval_step(
                item,
                max_generation_length,
                feed_gold_query=gold_forcing)
            # Normalize the loss by the gold query length.
            loss = loss / len(item.gold_query())

            flat_sequence = item.flatten_sequence(predicted_seq)
            token_accuracy = du.per_token_accuracy(item.gold_query(),
                                                   predicted_seq)

            if write_results:
                metrics_handler.write_prediction(
                    predictions_file,
                    identifier=item.interaction.identifier,
                    input_seq=item.input_sequence(),
                    probability=0,
                    prediction=predicted_seq,
                    flat_prediction=flat_sequence,
                    gold_query=item.gold_query(),
                    flat_gold_queries=item.original_gold_queries(),
                    gold_tables=item.gold_tables(),
                    index_in_interaction=item.utterance_index,
                    database_username=database_username,
                    database_password=database_password,
                    database_timeout=database_timeout)

            update_sums(metrics,
                        metrics_sums,
                        predicted_seq,
                        flat_sequence,
                        item.gold_query(),
                        item.original_gold_queries()[0],
                        gold_forcing,
                        loss,
                        token_accuracy,
                        database_username=database_username,
                        database_password=database_password,
                        database_timeout=database_timeout,
                        gold_table=item.gold_tables()[0])

            progbar.update(i)

        progbar.finish()

    return construct_averages(metrics_sums, total_num), None
def predict(model,
            utterances,
            prev_query=None,
            snippets=None,
            gold_seq=None,
            dropout_amount=0.,
            loss_only=False,
            beam_size=1.):
    """ Predicts a SQL query given an utterance and other various inputs.

    Inputs:
        model (Seq2SeqModel): The model to use to predict.
        utterances (list of list of str): The utterances to predict for.
        prev_query (list of str, optional): The previously generated query.
        snippets (list of Snippet, optional): The snippets available for
            prediction.
        gold_seq (list of str, optional): The gold sequence.
        dropout_amount (float, optional): How much dropout to apply during
            prediction.
        loss_only (bool, optional): Whether to only return the loss.
            Requires a non-empty gold_seq.
        beam_size (float, optional): How many items to include in the beam
            during prediction. Only passed to the decoder when no gold
            sequence is given.

    Returns:
        The value of du.compute_loss alone when loss_only is set; otherwise
        a tuple (sequence, scalar_loss, token_acc, probability), where
        scalar_loss is always the placeholder 0 and token_acc is 0 unless a
        gold sequence was given.
    """
    # The optional arguments default to None, so emptiness is tested by
    # truthiness: len(None) would raise TypeError before the assertion
    # could say anything useful.
    assert not prev_query or model.use_snippets
    assert not snippets or model.use_snippets
    assert not loss_only or gold_seq

    (enc_state, enc_outputs), input_seq = model.encode_input_sequences(
        utterances, dropout_amount)

    embedded_snippets = []
    if snippets:
        embedded_snippets = model.encode_snippets(
            prev_query, snippets, dropout_amount=dropout_amount)
        assert len(embedded_snippets) == len(snippets)

    decoder_snippets = embedded_snippets if model.use_snippets else []

    # Initialised up front: as this function is laid out, only the
    # free-decoding branch assigned scalar_loss, so the gold-forced path
    # raised NameError at the final return.
    # NOTE(review): 0 is the placeholder the free-decoding branch already
    # used -- confirm whether the computed loss should be surfaced instead.
    scalar_loss = 0

    if gold_seq:
        item = model.decode(
            enc_state,
            enc_outputs,
            input_seq,
            snippets=decoder_snippets,
            gold_seq=gold_seq,
            dropout_amount=dropout_amount)[0]

        # Each entry of item.scores is indexed as (scores, score_map).
        scores = item.scores
        scores_by_timestep = [score[0] for score in scores]
        score_maps_by_timestep = [score[1] for score in scores]

        num_snippets = len(snippets) if snippets else 0
        assert scores_by_timestep[0].dim()[0][0] == len(
            score_maps_by_timestep[0])
        assert len(score_maps_by_timestep[0]) >= len(
            model.output_vocab) + num_snippets

        loss = du.compute_loss(gold_seq,
                               scores_by_timestep,
                               score_maps_by_timestep,
                               gold_tok_to_id,
                               noise=0.00000000001)
        if loss_only:
            return loss

        sequence = du.get_seq_from_scores(scores_by_timestep,
                                          score_maps_by_timestep)
    else:
        item = model.decode(
            enc_state,
            enc_outputs,
            input_seq,
            snippets=decoder_snippets,
            beam_size=beam_size)[0]
        sequence = item.sequence

    token_acc = 0
    if gold_seq:
        token_acc = du.per_token_accuracy(gold_seq, sequence)

    return sequence, scalar_loss, token_acc, item.probability
def evaluate_utterance_sample(sample,
                              model,
                              max_generation_length,
                              name="",
                              gold_forcing=False,
                              metrics=None,
                              total_num=-1,
                              database_username="",
                              database_password="",
                              database_timeout=0,
                              write_results=False):
    """Evaluates a sample of utterance examples.

    Inputs:
        sample (list of Utterance): Examples to evaluate.
        model (ATISModel): Model to predict with.
        max_generation_length (int): Maximum length to generate.
        name (str): Name to log with; also the prefix of the
            "<name>_predictions.json" output file.
        gold_forcing (bool): Whether to force the gold tokens during decoding.
        metrics (list of Metric): Metrics to evaluate with. Must be non-empty.
        total_num (int): Number to divide by when reporting results. A
            negative value means len(sample).
        database_username (str): Username to use for executing queries.
        database_password (str): Password to use when executing queries.
        database_timeout (float): Timeout on queries when executing.
        write_results (bool): Whether to write the results to a file.

    Returns:
        tuple: (construct_averages(metrics_sums, total_num), None).

    Side effects:
        Creates (or truncates) "<name>_predictions.json" even when
        write_results is False, and prints the file name.
    """
    assert metrics

    if total_num < 0:
        total_num = len(sample)

    metrics_sums = {metric: 0. for metric in metrics}

    # Use a context manager so the predictions file is closed even when
    # evaluation of an item raises part-way through the sample.
    with open(name + "_predictions.json", "w") as predictions_file:
        print("Predicting with filename " + str(name) + "_predictions.json")
        progbar = get_progressbar(name, len(sample))
        progbar.start()

        for i, item in enumerate(sample):
            _, loss, predicted_seq = model.eval_step(
                item,
                max_generation_length,
                feed_gold_query=gold_forcing)
            # Normalize the loss by the gold query length.
            loss = loss / len(item.gold_query())

            flat_sequence = item.flatten_sequence(predicted_seq)
            token_accuracy = du.per_token_accuracy(item.gold_query(),
                                                   predicted_seq)

            if write_results:
                metrics_handler.write_prediction(
                    predictions_file,
                    identifier=item.interaction.identifier,
                    input_seq=item.input_sequence(),
                    probability=0,
                    prediction=predicted_seq,
                    flat_prediction=flat_sequence,
                    gold_query=item.gold_query(),
                    flat_gold_queries=item.original_gold_queries(),
                    gold_tables=item.gold_tables(),
                    index_in_interaction=item.utterance_index,
                    database_username=database_username,
                    database_password=database_password,
                    database_timeout=database_timeout)

            update_sums(metrics,
                        metrics_sums,
                        predicted_seq,
                        flat_sequence,
                        item.gold_query(),
                        item.original_gold_queries()[0],
                        gold_forcing,
                        loss,
                        token_accuracy,
                        database_username=database_username,
                        database_password=database_password,
                        database_timeout=database_timeout,
                        gold_table=item.gold_tables()[0])

            progbar.update(i)

        progbar.finish()

    return construct_averages(metrics_sums, total_num), None