Example 1
    def predict_turn(self,
                     utterance_final_state,
                     input_hidden_states,
                     max_generation_length,
                     gold_query=None,
                     snippets=None,
                     input_sequence=None,
                     feed_gold_tokens=False,
                     training=False,
                     first_utterance=True,
                     gold_copy=None):
        """ Gets a prediction for a single turn -- calls decoder and updates loss, etc.

        TODO:  this can probably be split into two methods, one that just predicts
            and another that computes the loss.
        """
        predicted_sequence = []
        fed_sequence = []
        loss = None
        token_accuracy = 0.

        if feed_gold_tokens:
            decoder_results, pick_loss = self.decoder(
                utterance_final_state,
                input_hidden_states,
                max_generation_length,
                gold_sequence=gold_query,
                input_sequence=input_sequence,
                snippets=snippets,
                dropout_amount=self.dropout,
                controller=self.controller,
                first_utterance=first_utterance,
                gold_copy=gold_copy)

            all_scores = [step.scores for step in decoder_results.predictions]
            all_alignments = [
                step.aligned_tokens for step in decoder_results.predictions
            ]
            # Compute the loss
            if not pick_loss:
                loss = du.compute_loss(gold_query, all_scores, all_alignments,
                                       get_token_indices)
            else:
                loss = du.compute_loss(gold_copy[1:], all_scores,
                                       all_alignments, get_token_indices)
            if pick_loss:
                loss += pick_loss
            if not loss:
                loss = dy.zeros((1, 1))

            if not training:
                predicted_sequence = du.get_seq_from_scores(
                    all_scores, all_alignments)

                token_accuracy = du.per_token_accuracy(gold_query,
                                                       predicted_sequence)

            fed_sequence = gold_query
        else:
            decoder_results, pick_loss = self.decoder(
                utterance_final_state,
                input_hidden_states,
                max_generation_length,
                input_sequence=input_sequence,
                snippets=snippets,
                dropout_amount=self.dropout,
                first_utterance=first_utterance)
            predicted_sequence = decoder_results.sequence

            fed_sequence = predicted_sequence

        # fed_sequence contains EOS, which we don't need when encoding snippets.
        # Also ignore the first state, as it contains the BEG encoding.
        decoder_states = [
            pred.decoder_state for pred in decoder_results.predictions
        ]

        if pick_loss:
            fed_sequence = fed_sequence[1:]
        for token, state in zip(fed_sequence[:-1], decoder_states[1:]):
            if snippet_handler.is_snippet(token):
                snippet_length = 0
                for snippet in snippets:
                    if snippet.name == token:
                        snippet_length = len(snippet.sequence)
                        break
                assert snippet_length > 0
                decoder_states.extend([state for _ in range(snippet_length)])
            else:
                decoder_states.append(state)

        return (predicted_sequence, loss, token_accuracy, decoder_states,
                decoder_results)
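For orientation, below is a hedged sketch of how predict_turn might be driven over a multi-turn interaction. The encode_utterance method, the interaction object, and MAX_GENERATION_LENGTH are illustrative assumptions, not part of this excerpt.

# Hypothetical driver loop; encode_utterance, interaction.utterances, and
# MAX_GENERATION_LENGTH are assumed names used only for illustration.
MAX_GENERATION_LENGTH = 200

def run_interaction(model, interaction, training=True):
    available_snippets = []
    losses = []
    for turn_index, utterance in enumerate(interaction.utterances):
        # Assumed encoder call producing the final state and per-token states.
        final_state, hidden_states = model.encode_utterance(
            utterance.input_sequence())
        _, loss, _, _, _ = model.predict_turn(
            final_state,
            hidden_states,
            MAX_GENERATION_LENGTH,
            gold_query=utterance.gold_query(),
            snippets=available_snippets,
            input_sequence=utterance.input_sequence(),
            feed_gold_tokens=training,
            training=training,
            first_utterance=(turn_index == 0))
        if loss is not None:
            losses.append(loss)
    return losses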
Example 2
def evaluate_utterance_sample(sample,
                              model,
                              max_generation_length,
                              name="",
                              gold_forcing=False,
                              metrics=None,
                              total_num=-1,
                              database_username="",
                              database_password="",
                              database_timeout=0,
                              write_results=False):
    assert metrics

    if total_num < 0:
        total_num = len(sample)

    metrics_sums = {}
    for metric in metrics:
        metrics_sums[metric] = 0.

    predictions_file = open(name + "_predictions.json", "w")
    print("Predicting with filename " + str(name) + "_predictions.json")
    progbar = get_progressbar(name, len(sample))
    progbar.start()

    predictions = []
    for i, item in enumerate(sample):
        _, loss, predicted_seq = model.eval_step(
            item, max_generation_length, feed_gold_query=gold_forcing)
        loss = loss / len(item.gold_query())
        predictions.append(predicted_seq)

        flat_sequence = item.flatten_sequence(predicted_seq)
        token_accuracy = du.per_token_accuracy(item.gold_query(),
                                               predicted_seq)

        if write_results:
            metrics_handler.write_prediction(
                predictions_file,
                identifier=item.interaction.identifier,
                input_seq=item.input_sequence(),
                probability=0,
                prediction=predicted_seq,
                flat_prediction=flat_sequence,
                gold_query=item.gold_query(),
                flat_gold_queries=item.original_gold_queries(),
                gold_tables=item.gold_tables(),
                index_in_interaction=item.utterance_index,
                database_username=database_username,
                database_password=database_password,
                database_timeout=database_timeout)

        update_sums(metrics,
                    metrics_sums,
                    predicted_seq,
                    flat_sequence,
                    item.gold_query(),
                    item.original_gold_queries()[0],
                    gold_forcing,
                    loss,
                    token_accuracy,
                    database_username=database_username,
                    database_password=database_password,
                    database_timeout=database_timeout,
                    gold_table=item.gold_tables()[0])

        progbar.update(i)

    progbar.finish()
    predictions_file.close()

    return construct_averages(metrics_sums, total_num), None
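construct_averages and update_sums are helpers that are not shown in this excerpt. A minimal sketch of construct_averages, assuming it simply divides each accumulated metric sum by the number of evaluated examples:

def construct_averages(metrics_sums, total_num):
    # Sketch based on usage above; the real helper may format or log
    # the values differently.
    averages = {}
    for metric, metric_sum in metrics_sums.items():
        averages[metric] = metric_sum / total_num if total_num else 0.
    return averages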
Example 3
def predict(model,
            utterances,
            prev_query=None,
            snippets=None,
            gold_seq=None,
            dropout_amount=0.,
            loss_only=False,
            beam_size=1.):
    """ Predicts a SQL query given an utterance and other various inputs.

    Inputs:
        model (Seq2SeqModel): The model to use to predict.
        utterances (list of list of str): The utterances to predict for.
        prev_query (list of str, optional): The previously generated query.
        snippets (list of Snippet, optional): The snippets available for prediction.
        gold_seq (list of str, optional): The gold sequence.
        dropout_amount (float, optional): How much dropout to apply during prediction.
        loss_only (bool, optional): Whether to only return the loss.
        beam_size (float, optional): How many items to include in the beam during prediction.
    """
    assert len(prev_query) == 0 or model.use_snippets
    assert len(snippets) == 0 or model.use_snippets
    assert not loss_only or len(gold_seq) > 0

    (enc_state, enc_outputs), input_seq = model.encode_input_sequences(
        utterances, dropout_amount)

    embedded_snippets = []
    if snippets:
        embedded_snippets = model.encode_snippets(
            prev_query, snippets, dropout_amount=dropout_amount)
        assert len(embedded_snippets) == len(snippets)

    if gold_seq:
        item = model.decode(
            enc_state,
            enc_outputs,
            input_seq,
            snippets=embedded_snippets if model.use_snippets else [],
            gold_seq=gold_seq,
            dropout_amount=dropout_amount)[0]
        scores = item.scores
        scores_by_timestep = [score[0] for score in scores]
        score_maps_by_timestep = [score[1] for score in scores]

        assert scores_by_timestep[0].dim()[0][0] == len(
            score_maps_by_timestep[0])
        assert len(score_maps_by_timestep[0]) >= len(
            model.output_vocab) + len(snippets)

        loss = du.compute_loss(gold_seq,
                               scores_by_timestep,
                               score_maps_by_timestep,
                               gold_tok_to_id,
                               noise=1e-11)

        if loss_only:
            return loss
        # Keep the computed loss so it is returned along with the decoded
        # sequence (otherwise scalar_loss would be undefined on this path).
        scalar_loss = loss
        sequence = du.get_seq_from_scores(scores_by_timestep,
                                          score_maps_by_timestep)
    else:
        item = model.decode(
            enc_state,
            enc_outputs,
            input_seq,
            snippets=embedded_snippets if model.use_snippets else [],
            beam_size=beam_size)[0]
        scalar_loss = 0
        sequence = item.sequence

    token_acc = 0
    if gold_seq:
        token_acc = du.per_token_accuracy(gold_seq, sequence)

    return sequence, scalar_loss, token_acc, item.probability
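du.per_token_accuracy appears in every example here but is not defined in the excerpt. A plausible minimal version, assuming a straightforward position-by-position comparison against the gold sequence:

def per_token_accuracy(gold_seq, predicted_seq):
    # Sketch only: the real utility may treat padding, EOS, or length
    # mismatches differently.
    if not gold_seq:
        return 0.
    num_correct = 0
    for position, gold_token in enumerate(gold_seq):
        if position < len(predicted_seq) and predicted_seq[position] == gold_token:
            num_correct += 1
    return num_correct / len(gold_seq)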
Example 4
def evaluate_utterance_sample(sample,
                              model,
                              max_generation_length,
                              name="",
                              gold_forcing=False,
                              metrics=None,
                              total_num=-1,
                              database_username="",
                              database_password="",
                              database_timeout=0,
                              write_results=False):
    """Evaluates a sample of utterance examples.

    Inputs:
        sample (list of Utterance): Examples to evaluate.
        model (ATISModel): Model to predict with.
        max_generation_length (int): Maximum length to generate.
        name (str): Name to log with.
        gold_forcing (bool): Whether to force the gold tokens during decoding.
        metrics (list of Metric): Metrics to evaluate with.
        total_num (int): Number to divide by when reporting results.
        database_username (str): Username to use for executing queries.
        database_password (str): Password to use when executing queries.
        database_timeout (float): Timeout on queries when executing.
        write_results (bool): Whether to write the results to a file.
    """
    assert metrics

    if total_num < 0:
        total_num = len(sample)

    metrics_sums = {}
    for metric in metrics:
        metrics_sums[metric] = 0.

    predictions_file = open(name + "_predictions.json", "w")
    print("Predicting with filename " + str(name) + "_predictions.json")
    progbar = get_progressbar(name, len(sample))
    progbar.start()

    predictions = []
    for i, item in enumerate(sample):
        _, loss, predicted_seq = model.eval_step(item,
                                                 max_generation_length,
                                                 feed_gold_query=gold_forcing)
        loss = loss / len(item.gold_query())
        predictions.append(predicted_seq)

        flat_sequence = item.flatten_sequence(predicted_seq)
        token_accuracy = du.per_token_accuracy(item.gold_query(),
                                               predicted_seq)

        if write_results:
            metrics_handler.write_prediction(
                predictions_file,
                identifier=item.interaction.identifier,
                input_seq=item.input_sequence(),
                probability=0,
                prediction=predicted_seq,
                flat_prediction=flat_sequence,
                gold_query=item.gold_query(),
                flat_gold_queries=item.original_gold_queries(),
                gold_tables=item.gold_tables(),
                index_in_interaction=item.utterance_index,
                database_username=database_username,
                database_password=database_password,
                database_timeout=database_timeout)

        update_sums(metrics,
                    metrics_sums,
                    predicted_seq,
                    flat_sequence,
                    item.gold_query(),
                    item.original_gold_queries()[0],
                    gold_forcing,
                    loss,
                    token_accuracy,
                    database_username=database_username,
                    database_password=database_password,
                    database_timeout=database_timeout,
                    gold_table=item.gold_tables()[0])

        progbar.update(i)

    progbar.finish()
    predictions_file.close()

    return construct_averages(metrics_sums, total_num), None
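A hypothetical call site for the evaluation entry point above; dev_examples, trained_model, and the Metrics members are assumed names used only to illustrate the expected arguments.

# Illustrative invocation; all names below are assumptions, not part of
# this excerpt.
dev_metrics, _ = evaluate_utterance_sample(
    dev_examples,
    trained_model,
    max_generation_length=200,
    name="dev",
    gold_forcing=False,
    metrics=[Metrics.LOSS, Metrics.TOKEN_ACCURACY, Metrics.STRING_ACCURACY],
    database_username="db_user",
    database_password="db_password",
    database_timeout=3,
    write_results=True)
print(dev_metrics)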