コード例 #1
0
ファイル: interaction_model.py プロジェクト: lanlanabcd/atis
    def predict_with_gold_queries(self,
                                  interaction,
                                  max_generation_length,
                                  feed_gold_query=False):
        """Run the decoder over every gold utterance of an interaction.

        Each utterance is encoded with the current discourse state appended to
        its token embeddings. The discourse LSTM is then advanced, and
        ``self.predict_turn`` is called with that utterance's gold query.

        Inputs:
            interaction (Interaction): Interaction to predict for.
            max_generation_length (int): Maximum generation length passed to
                ``predict_turn`` for every utterance.
            feed_gold_query (bool): Whether or not to feed the gold token to the
                generation step.

        Returns:
            list: One ``predict_turn`` result per gold utterance, in order.
        """
        assert self.params.discourse_level_lstm

        dy.renew_cg()

        turn_results = []
        all_hidden_states = []
        seen_sequences = []
        previous_decoder_states = []

        discourse_state, discourse_lstm_states = (
            self._initialize_discourse_states())

        for utterance in interaction.gold_utterances():
            tokens = utterance.input_sequence()
            candidate_snippets = utterance.snippets()
            prior_query = utterance.previous_query()

            def embed_with_discourse(token):
                # Reads ``discourse_state`` at call time, i.e. the state left
                # by the previous utterance.
                return dy.concatenate(
                    [self.input_embedder(token), discourse_state])

            final_utterance_state, utterance_states = self.utterance_encoder(
                tokens, embed_with_discourse, dropout_amount=self.dropout)

            all_hidden_states.extend(utterance_states)
            seen_sequences.append(tokens)

            # Advance the discourse-level LSTM using final_utterance_state[1][0].
            _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
                final_utterance_state[1][0], discourse_lstm_states,
                self.dropout)

            # Concatenate the most recent utterances (bounded by
            # params.maximum_utterances) into a single token list.
            keep = min(self.params.maximum_utterances, len(seen_sequences))
            flat_sequence = [token
                             for past_tokens in seen_sequences[-keep:]
                             for token in past_tokens]

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    all_hidden_states, seen_sequences)

            snippets = None
            if self.params.use_snippets:
                snippets = (
                    encode_snippets_with_states(candidate_snippets,
                                                previous_decoder_states)
                    if self.params.previous_decoder_snippet_encoding
                    else self._encode_snippets(prior_query, candidate_snippets))

            turn = self.predict_turn(final_utterance_state,
                                     utterance_states,
                                     max_generation_length,
                                     gold_query=utterance.gold_query(),
                                     snippets=snippets,
                                     input_sequence=flat_sequence,
                                     feed_gold_tokens=feed_gold_query)
            # Element 3 of the result is kept as the decoder states that the
            # next utterance's snippet encoding may read.
            previous_decoder_states = turn[3]
            turn_results.append(turn)

        return turn_results
コード例 #2
0
    def interactive_prediction(self, anonymizer):
        """Interactively predict SQL for utterances typed at a ``> `` prompt.

        The first utterance is a fixed example sentence; every later one is
        read with ``input("> ")``. The loop stops once the lower-cased
        utterance is in ``END_OF_INTERACTION``.

        Inputs:
            anonymizer: Object whose ``anonymize`` method is applied to the
                tokenized utterance when ``self.params.anonymize`` is set.

        NOTE(review): still a prototype. Executing the predicted query and
        updating ``snippet_bank`` / ``previous_query`` between turns are not
        implemented, so both stay empty for the whole session.
        """
        dy.renew_cg()

        snippet_bank = []
        anonymization_dictionary = {}
        previous_query = []

        input_hidden_states = []
        input_sequences = []
        # Decoder states of the previous turn. They are read when snippets are
        # encoded from them (params.previous_decoder_snippet_encoding). This
        # name was previously never bound here.
        decoder_states = []

        discourse_state, discourse_lstm_states = self._initialize_discourse_states()

        # Fixed seed utterance (the original debugging input). Subsequent
        # utterances are read from the prompt at the end of each turn.
        utterance = "show me flights from new york to boston"
        while utterance.lower() not in END_OF_INTERACTION:

            # First, normalize the utterance and (optionally) anonymize it.
            tokenized_sequence = tokenizers.nl_tokenize(utterance)

            # snippet_bank is never populated in this method, so this is
            # currently always empty.
            available_snippets = [snippet for snippet in snippet_bank
                                  if snippet.index <= 1]

            sequence_to_use = tokenized_sequence

            # TODO: implement date normalization

            if self.params.anonymize:
                sequence_to_use = anonymizer.anonymize(tokenized_sequence,
                                                       anonymization_dictionary,
                                                       ANON_INPUT_KEY,
                                                       add_new_anon_toks=True)

            # Encode the sequence, appending the current discourse state to
            # every token embedding.
            final_utterance_state, utterance_states = self.utterance_encoder(
                sequence_to_use,
                lambda token: dy.concatenate([self.input_embedder(token), discourse_state]))

            input_hidden_states.extend(utterance_states)
            input_sequences.append(sequence_to_use)

            # Update the discourse state.
            _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
                final_utterance_state[1][0], discourse_lstm_states)

            # Flatten the most recent utterances (at most maximum_utterances),
            # optionally replaced by the positional-embedding version.
            flat_sequence = []
            num_utterances_to_keep = min(self.params.maximum_utterances, len(input_sequences))
            for utt in input_sequences[-num_utterances_to_keep:]:
                flat_sequence.extend(utt)

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    input_hidden_states, input_sequences)

            # Encode the snippets. Default to None so predict_turn always gets
            # a bound value (previously unbound when use_snippets was off).
            snippets = None
            if self.params.use_snippets:
                if self.params.previous_decoder_snippet_encoding:
                    snippets = encode_snippets_with_states(available_snippets, decoder_states)
                else:
                    snippets = self._encode_snippets(
                        previous_query, available_snippets)

            # Predict a result.
            results = self.predict_turn(final_utterance_state,
                                        utterance_states,
                                        self.params.eval_maximum_sql_length,
                                        input_sequence=flat_sequence,
                                        snippets=snippets)
            # Element 3 is treated as the decoder states, as in
            # predict_with_gold_queries / train.
            decoder_states = results[3]

            predicted_sequence = results[0]

            # NOTE(review): the original printed
            # ``utterance.remove_snippets(...)`` and ``utterance.flatten_sequence(...)``,
            # but ``utterance`` is a plain ``str`` here and has neither method,
            # so it raised AttributeError. The raw predicted tokens are shown
            # until snippet expansion and de-anonymization are wired in.
            print(" ".join(str(token) for token in predicted_sequence))

            # TODO: execute the query and show the results.
            # TODO: update the available snippets, previous_query, etc.

            # Read the next utterance inside the loop, so the session advances
            # and can terminate (this was previously after the loop).
            utterance = input("> ")
コード例 #3
0
ファイル: interaction_model.py プロジェクト: lanlanabcd/atis
    def train(self,
              interaction,
              max_generation_length,
              snippet_alignment_probability=1.):
        """ Trains the interaction-level model on a single interaction.

        Inputs:
            interaction (Interaction): The interaction to train on.
            max_generation_length (int): Utterances whose gold query or previous
                query is longer than this are skipped.
            snippet_alignment_probability (float): The probability that a snippet will
                be used in constructing the gold sequence. Values below 1
                rebuild the gold query with ``sql_util.add_snippets_to_query``.

        Returns:
            float: The loss summed over trained utterances and divided by the
            number of gold tokens. It is optionally rescaled by
            ``len(losses) / params.batch_size`` when ``params.reweight_batch``
            is set. Returns 0. if no utterance was trained on. If at least one
            utterance was trained on, a single ``self.trainer.update()`` is
            applied.
        """
        assert self.params.discourse_level_lstm

        dy.renew_cg()

        losses = []
        total_gold_tokens = 0

        input_hidden_states = []
        input_sequences = []

        decoder_states = []

        discourse_state, discourse_lstm_states = self._initialize_discourse_states(
        )

        # Whether the next utterance sent to predict_turn is the first one
        # trained on in this interaction (forwarded as ``first_utterance``).
        new_turn = True

        for utterance_index, utterance in enumerate(
                interaction.gold_utterances()):
            if interaction.identifier in LIMITED_INTERACTIONS \
                    and utterance_index > LIMITED_INTERACTIONS[interaction.identifier]:
                break

            input_sequence = utterance.input_sequence()

            available_snippets = utterance.snippets()
            previous_query = utterance.previous_query()

            # Bind gold_copy on every path. It was previously only assigned in
            # the else-branch, so snippet_alignment_probability < 1 raised
            # UnboundLocalError.
            gold_copy = None

            # Get the gold query: reconstruct if the alignment probability
            # is less than one
            if snippet_alignment_probability < 1.:
                # NOTE(review): copy labels from utterance.gold_copy() refer to
                # the unmodified gold query, so they are not used for the
                # reconstructed one. Confirm that predict_turn accepts
                # gold_copy=None when params.copy is set.
                gold_query = sql_util.add_snippets_to_query(
                    available_snippets,
                    utterance.contained_entities(),
                    utterance.anonymized_gold_query(),
                    prob_align=snippet_alignment_probability) + [
                        vocab.EOS_TOK
                    ]
            else:
                gold_query = utterance.gold_query()
                if self.params.copy:
                    gold_copy = utterance.gold_copy()

            # Encode the utterance, and update the discourse-level states
            final_utterance_state, utterance_states = self.utterance_encoder(
                input_sequence,
                lambda token: dy.concatenate(
                    [self.input_embedder(token), discourse_state]),
                dropout_amount=self.dropout)

            input_hidden_states.extend(utterance_states)
            input_sequences.append(input_sequence)

            _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
                final_utterance_state[1][0], discourse_lstm_states,
                self.dropout)

            # Flatten the most recent utterances (at most maximum_utterances).
            flat_sequence = []
            num_utterances_to_keep = min(self.params.maximum_utterances,
                                         len(input_sequences))
            for utt in input_sequences[-num_utterances_to_keep:]:
                flat_sequence.extend(utt)

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    input_hidden_states, input_sequences)

            snippets = None
            if self.params.use_snippets:
                if self.params.previous_decoder_snippet_encoding:
                    snippets = encode_snippets_with_states(
                        available_snippets, decoder_states)
                else:
                    snippets = self._encode_snippets(previous_query,
                                                     available_snippets)

            if len(gold_query) <= max_generation_length \
                    and len(previous_query) <= max_generation_length:
                prediction = self.predict_turn(final_utterance_state,
                                               utterance_states,
                                               max_generation_length,
                                               gold_query=gold_query,
                                               snippets=snippets,
                                               input_sequence=flat_sequence,
                                               feed_gold_tokens=True,
                                               training=True,
                                               first_utterance=new_turn,
                                               gold_copy=gold_copy)
                new_turn = False
                loss = prediction[1]
                decoder_states = prediction[3]
                total_gold_tokens += len(gold_query)
                losses.append(loss)
            else:
                # Break if previous decoder snippet encoding -- because the previous
                # sequence was too long to run the decoder.
                if self.params.previous_decoder_snippet_encoding:
                    break
                continue

        if losses:
            average_loss = dy.esum(losses) / total_gold_tokens

            # Renormalize so the effect is normalized by the batch size.
            normalized_loss = average_loss
            if self.params.reweight_batch:
                normalized_loss = len(losses) * average_loss / \
                    float(self.params.batch_size)
            normalized_loss.forward()
            normalized_loss.backward()
            self.trainer.update()
            loss_scalar = normalized_loss.value()
        else:
            loss_scalar = 0.

        return loss_scalar
コード例 #4
0
    def predict_with_predicted_queries(self, interaction, max_generation_length, syntax_restrict=True):
        """ Predicts an interaction, using the predicted queries to get snippets.

        Each utterance is encoded and decoded without gold supervision. The
        prediction is stored on the utterance and passed to
        ``interaction.add_utterance``. If ``syntax_restrict`` is set and the
        flattened prediction is not executable according to
        ``sql_util.executable``, the utterance's previous query is stored and
        added instead.

        Inputs:
            interaction (Interaction): Interaction to predict for. It is
                driven through start_interaction / next_utterance /
                add_utterance.
            max_generation_length (int): Maximum length passed to predict_turn.
            syntax_restrict (bool): Whether to fall back to the previous query
                when the predicted query does not execute.

        Returns:
            list: One predict_turn result per utterance, in order.
        """
        assert self.params.discourse_level_lstm

        dy.renew_cg()

        predictions = []

        input_hidden_states = []
        input_sequences = []
        # Decoder states of the previous turn. They are read when snippets are
        # encoded from them (params.previous_decoder_snippet_encoding). This
        # name was previously never bound here, which caused a NameError.
        decoder_states = []

        discourse_state, discourse_lstm_states = self._initialize_discourse_states()

        interaction.start_interaction()
        while not interaction.done():
            # TODO: snippet keep age here
            utterance = interaction.next_utterance()

            # TODO: make sure these are all of the correct age
            available_snippets = utterance.snippets()
            previous_query = utterance.previous_query()

            input_sequence = utterance.input_sequence()
            final_utterance_state, utterance_states = self.utterance_encoder(
                input_sequence,
                lambda token: dy.concatenate([self.input_embedder(token), discourse_state]))

            input_hidden_states.extend(utterance_states)
            input_sequences.append(input_sequence)

            _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
                final_utterance_state[1][0], discourse_lstm_states)

            # Flatten the most recent utterances (at most maximum_utterances).
            flat_sequence = []
            num_utterances_to_keep = min(self.params.maximum_utterances, len(input_sequences))
            for utt in input_sequences[-num_utterances_to_keep:]:
                flat_sequence.extend(utt)

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    input_hidden_states, input_sequences)

            snippets = None
            if self.params.use_snippets:
                if self.params.previous_decoder_snippet_encoding:
                    snippets = encode_snippets_with_states(available_snippets, decoder_states)
                else:
                    snippets = self._encode_snippets(
                        previous_query, available_snippets)

            results = self.predict_turn(final_utterance_state,
                                        utterance_states,
                                        max_generation_length,
                                        input_sequence=flat_sequence,
                                        snippets=snippets)
            # Element 3 is treated as the decoder states, as in
            # predict_with_gold_queries / train.
            decoder_states = results[3]

            predicted_sequence = results[0]
            predictions.append(results)

            # Update things necessary for using predicted queries. The last
            # token of the snippet-free sequence is dropped.
            anonymized_sequence = utterance.remove_snippets(predicted_sequence)[:-1]
            # Separate name so the flattened SQL is not confused with the
            # flattened input tokens above.
            flat_prediction = utterance.flatten_sequence(predicted_sequence)

            if not syntax_restrict or sql_util.executable(
                    flat_prediction,
                    username=self.params.database_username,
                    password=self.params.database_password,
                    timeout=self.params.database_timeout):
                utterance.set_pred_query(
                    interaction.remove_snippets(predicted_sequence))
                interaction.add_utterance(
                    utterance,
                    anonymized_sequence,
                    previous_snippets=utterance.snippets())

            else:
                # Fall back to the previous query when the prediction does not
                # execute. NOTE(review): this branch previously called
                # ``set_predicted_query``. It now uses ``set_pred_query``, the
                # setter the main branch above relies on. Confirm that
                # Utterance does not define both with different meanings.
                utterance.set_pred_query(utterance.previous_query())
                interaction.add_utterance(
                    utterance,
                    utterance.previous_query(),
                    previous_snippets=utterance.snippets())

        return predictions