Esempio n. 1
0
    def eval_step(self, example, feed_gold_query=False):
        """Runs the model on a single example without updating any parameters.

        Inputs:
            example (utterance example): Example to feed.
            feed_gold_query (bool): Whether or not to token-feed the gold query
                (and compute a loss against it).

        Returns:
            tuple: (decoder_results, loss, predicted_seq). ``loss`` is the
            ``dy.zeros((1, 1))`` placeholder unless ``feed_gold_query`` is True.
        """
        dy.renew_cg()

        # Encode the retained history followed by the current utterance.
        history = example.histories(self.params.maximum_utterances - 1)
        input_sequences = history + [example.input_sequence()]
        final_state, utterance_hidden_states = self._encode_input_sequences(
            input_sequences)

        if self.params.state_positional_embeddings:
            # The flattened sequence returned alongside is not needed here;
            # the decoder input is flattened separately below.
            utterance_hidden_states, _ = self._add_positional_embeddings(
                utterance_hidden_states, input_sequences)

        # NOTE(review): an empty list is handed to _encode_snippets as the
        # snippet collection — confirm this is intended.
        if self.params.use_snippets:
            snippets = self._encode_snippets(example.previous_query(), [])
        else:
            snippets = []

        flat_tokens = [token for sequence in input_sequences for token in sequence]
        gold_sequence = example.gold_query() if feed_gold_query else None

        # NOTE(review): the *training* maximum SQL length is used during
        # evaluation — confirm this is intended.
        decoder_results = self.decoder(
            final_state,
            utterance_hidden_states,
            self.params.train_maximum_sql_length,
            snippets=snippets,
            gold_sequence=gold_sequence,
            dropout_amount=self.dropout,
            input_sequence=flat_tokens,
            controller=self.controller)

        steps = decoder_results.predictions
        all_scores = [step.scores for step in steps]
        all_alignments = [step.aligned_tokens for step in steps]

        loss = dy.zeros((1, 1))
        if feed_gold_query:
            loss = du.compute_loss(example.gold_query(), all_scores,
                                   all_alignments, get_token_indices)

        return (decoder_results, loss,
                du.get_seq_from_scores(all_scores, all_alignments))
Esempio n. 2
0
    def predict_turn(self,
                     utterance_final_state,
                     input_hidden_states,
                     max_generation_length,
                     gold_query=None,
                     snippets=None,
                     input_sequence=None,
                     feed_gold_tokens=False,
                     training=False,
                     first_utterance=True,
                     gold_copy=None):
        """ Gets a prediction for a single turn -- calls decoder and updates loss, etc.

        Inputs:
            utterance_final_state: Final encoder state handed to the decoder.
            input_hidden_states: Encoder hidden states handed to the decoder.
            max_generation_length (int): Maximum number of decoding steps.
            gold_query (list, optional): Gold query; token-fed to the decoder
                and used as the loss target when ``feed_gold_tokens`` is True.
            snippets (list, optional): Snippets available to the decoder; each
                must expose ``name`` and ``sequence``.
            input_sequence (list, optional): Flattened input tokens.
            feed_gold_tokens (bool): Whether to token-feed ``gold_query`` and
                compute a loss.
            training (bool): When True (and feeding gold tokens), skip building
                the predicted sequence and token accuracy.
            first_utterance (bool): Forwarded to the decoder.
            gold_copy (list, optional): Forwarded to the decoder when feeding
                gold tokens; ``gold_copy[1:]`` is the loss target whenever the
                decoder reports a (truthy) pick loss.

        Returns:
            tuple: (predicted_sequence, loss, token_accuracy, decoder_states,
            decoder_results). ``loss`` is None unless ``feed_gold_tokens``.
            ``decoder_states`` holds one decoder state per fed token (the final
            EOS token and the initial BEG state excluded), with each snippet
            token expanded to one copy of its state per snippet token.

        TODO:  this can probably be split into two methods, one that just predicts
            and another that computes the loss.
        """
        predicted_sequence = []
        fed_sequence = []
        loss = None
        token_accuracy = 0.

        if feed_gold_tokens:
            decoder_results, pick_loss = self.decoder(
                utterance_final_state,
                input_hidden_states,
                max_generation_length,
                gold_sequence=gold_query,
                input_sequence=input_sequence,
                snippets=snippets,
                dropout_amount=self.dropout,
                controller=self.controller,
                first_utterance=first_utterance,
                gold_copy=gold_copy)

            all_scores = [step.scores for step in decoder_results.predictions]
            all_alignments = [
                step.aligned_tokens for step in decoder_results.predictions
            ]
            # Compute the loss: with a pick loss, score against gold_copy[1:]
            # and add the pick loss on top; otherwise score against gold_query.
            if pick_loss:
                loss = du.compute_loss(gold_copy[1:], all_scores,
                                       all_alignments, get_token_indices)
                loss += pick_loss
            else:
                loss = du.compute_loss(gold_query, all_scores, all_alignments,
                                       get_token_indices)
            if not loss:
                loss = dy.zeros((1, 1))

            if not training:
                predicted_sequence = du.get_seq_from_scores(
                    all_scores, all_alignments)
                token_accuracy = du.per_token_accuracy(gold_query,
                                                       predicted_sequence)

            fed_sequence = gold_query
        else:
            decoder_results, pick_loss = self.decoder(
                utterance_final_state,
                input_hidden_states,
                max_generation_length,
                input_sequence=input_sequence,
                snippets=snippets,
                dropout_amount=self.dropout,
                first_utterance=first_utterance)
            predicted_sequence = decoder_results.sequence
            fed_sequence = predicted_sequence

        step_states = [
            pred.decoder_state for pred in decoder_results.predictions
        ]

        # NOTE(review): with a pick loss the first fed token is dropped,
        # presumably to keep tokens aligned with step_states[1:] — confirm.
        if pick_loss:
            fed_sequence = fed_sequence[1:]

        # fed_sequence contains EOS (its last token), which we don't need when
        # encoding snippets; also ignore the first state, as it contains the
        # BEG encoding. Collect into a NEW list: appending to step_states
        # itself would leave every raw state (BEG included) at the front of
        # the result.
        decoder_states = []
        for token, state in zip(fed_sequence[:-1], step_states[1:]):
            if snippet_handler.is_snippet(token):
                # Length of the first snippet with this name (0 if none).
                snippet_length = next(
                    (len(snippet.sequence) for snippet in snippets or ()
                     if snippet.name == token), 0)
                assert snippet_length > 0
                decoder_states.extend([state] * snippet_length)
            else:
                decoder_states.append(state)

        return (predicted_sequence, loss, token_accuracy, decoder_states,
                decoder_results)
Esempio n. 3
0
    def train_step(self, batch):
        """Performs one parameter update from a batch of examples.

        Input:
            batch: Batch object exposing ``start()``, ``done()`` and ``next()``;
                each ``next()`` call yields one example.

        Returns:
            The value of the batch loss: the sum of per-example losses divided
            by the total number of gold query tokens in the batch.
        """
        dy.renew_cg(autobatching=True)

        example_losses = []
        gold_token_count = 0

        batch.start()
        while not batch.done():
            # Original (translated) note: "each batch holds 16 interactions,
            # each a set of a variable number of utterances".
            # NOTE(review): that size is not enforced by this code — confirm.
            example = batch.next()

            # Encode the retained history followed by the current utterance.
            history = example.histories(self.params.maximum_utterances - 1)
            input_sequences = history + [example.input_sequence()]
            final_state, utterance_hidden_states = self._encode_input_sequences(
                input_sequences)

            if self.params.state_positional_embeddings:
                # Only the updated hidden states are needed here.
                utterance_hidden_states, _ = self._add_positional_embeddings(
                    utterance_hidden_states, input_sequences)

            if self.params.use_snippets:
                snippets = self._encode_snippets(example.previous_query(), [])
            else:
                snippets = []

            flat_tokens = [tok for sequence in input_sequences for tok in sequence]
            decoder_results = self.decoder(
                final_state,
                utterance_hidden_states,
                self.params.train_maximum_sql_length,
                snippets=snippets,
                gold_sequence=example.gold_query(),
                dropout_amount=self.dropout,
                input_sequence=flat_tokens,
                controller=self.controller)

            steps = decoder_results.predictions
            scores = [step.scores for step in steps]
            alignments = [step.aligned_tokens for step in steps]

            example_losses.append(du.compute_loss(example.gold_query(),
                                                  scores,
                                                  alignments,
                                                  get_token_indices))
            gold_token_count += len(example.gold_query())

        # Normalize by gold tokens, then backpropagate and step the trainer.
        average_loss = dy.esum(example_losses) / gold_token_count
        average_loss.forward()
        average_loss.backward()
        self.trainer.update()
        return average_loss.value()
Esempio n. 4
0
def predict(model,
            utterances,
            prev_query=None,
            snippets=None,
            gold_seq=None,
            dropout_amount=0.,
            loss_only=False,
            beam_size=1.):
    """ Predicts a SQL query given an utterance and other various inputs.

    Inputs:
        model (Seq2SeqModel): The model to use to predict.
        utterances (list of list of str): The utterances to predict for.
        prev_query (list of str, optional): The previously generated query.
            None is treated as an empty query.
        snippets (list of Snippet, optional): The snippets available for
            prediction. None is treated as no snippets.
        gold_seq (list of str, optional): The gold sequence. When non-empty,
            the decoder is token-fed this sequence and a loss is computed.
        dropout_amount (float, optional): How much dropout to apply during prediction.
        loss_only (bool, optional): Whether to only return the loss; requires
            a non-empty ``gold_seq``.
        beam_size (float, optional): How many items to include in the beam
            during prediction (used only when no ``gold_seq`` is given).

    Returns:
        If ``loss_only``: the loss returned by ``du.compute_loss``.
        Otherwise a tuple ``(sequence, scalar_loss, token_acc, probability)``.
        ``scalar_loss`` is the loss value when ``gold_seq`` is given and 0
        otherwise. ``token_acc`` is 0 when no ``gold_seq`` is given.
    """
    # Normalize the optional list arguments so the len() checks below work
    # with the defaults.
    if prev_query is None:
        prev_query = []
    if snippets is None:
        snippets = []

    assert len(prev_query) == 0 or model.use_snippets
    assert len(snippets) == 0 or model.use_snippets
    assert not loss_only or gold_seq

    (enc_state, enc_outputs), input_seq = model.encode_input_sequences(
        utterances, dropout_amount)

    embedded_snippets = []
    if snippets:
        embedded_snippets = model.encode_snippets(
            prev_query, snippets, dropout_amount=dropout_amount)
        assert len(embedded_snippets) == len(snippets)

    # Snippets are only handed to the decoder when the model uses them.
    decoder_snippets = embedded_snippets if model.use_snippets else []

    if gold_seq:
        item = model.decode(
            enc_state,
            enc_outputs,
            input_seq,
            snippets=decoder_snippets,
            gold_seq=gold_seq,
            dropout_amount=dropout_amount)[0]
        # Each entry of item.scores pairs a score vector with a map whose
        # length matches the score dimension (checked below).
        scores_by_timestep = [score[0] for score in item.scores]
        score_maps_by_timestep = [score[1] for score in item.scores]

        assert scores_by_timestep[0].dim()[0][0] == len(
            score_maps_by_timestep[0])
        assert len(score_maps_by_timestep[0]) >= len(
            model.output_vocab) + len(snippets)

        loss = du.compute_loss(gold_seq,
                               scores_by_timestep,
                               score_maps_by_timestep,
                               gold_tok_to_id,
                               noise=0.00000000001)

        if loss_only:
            return loss
        # scalar_loss must be bound on this path too (it is returned below).
        # NOTE(review): assumes `loss` is a DyNet expression exposing
        # `.value()`; confirm against du.compute_loss.
        scalar_loss = loss.value()
        sequence = du.get_seq_from_scores(scores_by_timestep,
                                          score_maps_by_timestep)
        token_acc = du.per_token_accuracy(gold_seq, sequence)
    else:
        item = model.decode(
            enc_state,
            enc_outputs,
            input_seq,
            snippets=decoder_snippets,
            beam_size=beam_size)[0]
        scalar_loss = 0
        sequence = item.sequence
        token_acc = 0

    return sequence, scalar_loss, token_acc, item.probability