Example #1
File: model.py Project: we1l1n/atis
    def _encode_with_discourse_lstm(self, utterances):
        """ Encodes the utterances using a discourse-level LSTM, instead of concatenating.

        Inputs:
            utterances (list of list of str): Utterances.
        """
        hidden_states = []

        discourse_state, discourse_lstm_states = self._initialize_discourse_states()

        final_state = None
        for utterance in utterances:
            final_state, utterance_states = self.utterance_encoder(
                utterance,
                lambda token: dy.concatenate(
                    [self.input_embedder(token), discourse_state]),
                dropout_amount=self.dropout)

            hidden_states.extend(utterance_states)

            _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
                final_state, discourse_lstm_states, self.dropout)

        return final_state, hidden_states
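A minimal usage sketch (not from the project): assuming `model` is an instance of the class above and the utterances are already tokenized, the encoder could be driven like this.

# Hypothetical usage; `model` and the utterances are illustrative.
utterances = [
    ["show", "me", "flights", "from", "denver"],
    ["which", "ones", "arrive", "before", "noon"],
]
final_state, hidden_states = model._encode_with_discourse_lstm(utterances)
# hidden_states holds one encoder state per token across all utterances;
# final_state is the encoder state after the last utterance.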
Example #2
    def interactive_prediction(self, anonymizer):
        """Interactive prediction.

        Inputs:
            anonymizer (Anonymizer): Anonymizer to use for user's input.
        """
        dy.renew_cg()

        snippet_bank = []
        anonymization_dictionary = {}
        previous_query = []

        input_hidden_states = []
        input_sequences = []
        final_utterance_state = None

        discourse_state, discourse_lstm_states = self._initialize_discourse_states()

        utterance = "show me flights from new york to boston"  # input("> ")
        while utterance.lower() not in END_OF_INTERACTION:

            # First, normalize the utterance and update the anonymization
            # dictionary.
            tokenized_sequence = tokenizers.nl_tokenize(utterance)

            available_snippets = [
                snippet for snippet in snippet_bank if snippet.index <= 1
            ]

            sequence_to_use = tokenized_sequence

            # TODO: implement date normalization

            if self.params.anonymize:
                sequence_to_use = anonymizer.anonymize(
                    tokenized_sequence,
                    anonymization_dictionary,
                    ANON_INPUT_KEY,
                    add_new_anon_toks=True)

            # Now we encode the sequence
            final_utterance_state, utterance_states = self.utterance_encoder(
                sequence_to_use, lambda token: dy.concatenate(
                    [self.input_embedder(token), discourse_state]))

            input_hidden_states.extend(utterance_states)
            input_sequences.append(sequence_to_use)

            # Now update the discourse state
            _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
                final_utterance_state[1][0], discourse_lstm_states)

            # Flatten the most recent utterances into a single input sequence
            flat_sequence = []
            num_utterances_to_keep = min(self.params.maximum_utterances,
                                         len(input_sequences))
            for utt in input_sequences[-num_utterances_to_keep:]:
                flat_sequence.extend(utt)

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    input_hidden_states, input_sequences)

            # Encode the snippets
            snippets = None
            if self.params.use_snippets:
                snippets = self._encode_snippets(previous_query,
                                                 available_snippets)

            # Predict a result
            results = self.predict_turn(final_utterance_state,
                                        utterance_states,
                                        self.params.eval_maximum_sql_length,
                                        input_sequence=flat_sequence,
                                        snippets=snippets)

            # Get the sequence, and show the de-anonymized and flattened
            # versions
            predicted_sequence = results[0]
            print(predicted_sequence)

            # TODO: execute the query and show the results.

            # TODO: update the available snippets, etc.

        utterance = input("> ")
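A sketch of how this might be launched (names assumed, not from the project): the method runs its own read-predict loop, so the caller only supplies an anonymizer.

# Hypothetical driver; `model` and `anonymizer` are assumed to be built elsewhere.
model.interactive_prediction(anonymizer)
# The loop prints one predicted SQL sequence per input utterance and exits
# when the input matches END_OF_INTERACTION.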
Example #3
    def predict_with_gold_queries(self,
                                  interaction,
                                  max_generation_length,
                                  feed_gold_query=False):
        """ Predicts SQL queries for an interaction.

        Inputs:
            interaction (Interaction): Interaction to predict for.
            feed_gold_query (bool): Whether or not to feed the gold token to the
                generation step.
        """
        assert self.params.discourse_level_lstm

        dy.renew_cg()

        predictions = []

        input_hidden_states = []
        input_sequences = []
        final_utterance_state = None

        decoder_states = []

        discourse_state, discourse_lstm_states = self._initialize_discourse_states()

        for utterance in interaction.gold_utterances():
            input_sequence = utterance.input_sequence()

            available_snippets = utterance.snippets()
            previous_query = utterance.previous_query()

            # Encode the utterance, and update the discourse-level states
            final_utterance_state, utterance_states = self.utterance_encoder(
                input_sequence,
                lambda token: dy.concatenate(
                    [self.input_embedder(token), discourse_state]),
                dropout_amount=self.dropout)

            input_hidden_states.extend(utterance_states)
            input_sequences.append(input_sequence)

            _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
                final_utterance_state[1][0], discourse_lstm_states,
                self.dropout)

            flat_sequence = []
            num_utterances_to_keep = min(self.params.maximum_utterances,
                                         len(input_sequences))
            for utt in input_sequences[-num_utterances_to_keep:]:
                flat_sequence.extend(utt)

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    input_hidden_states, input_sequences)

            snippets = None
            if self.params.use_snippets:
                if self.params.previous_decoder_snippet_encoding:
                    snippets = encode_snippets_with_states(
                        available_snippets, decoder_states)
                else:
                    snippets = self._encode_snippets(previous_query,
                                                     available_snippets)

            prediction = self.predict_turn(final_utterance_state,
                                           utterance_states,
                                           max_generation_length,
                                           gold_query=utterance.gold_query(),
                                           snippets=snippets,
                                           input_sequence=flat_sequence,
                                           feed_gold_tokens=feed_gold_query)
            decoder_states = prediction[3]
            predictions.append(prediction)

        return predictions
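A hedged sketch of teacher-forced evaluation built on this method; `dev_interactions` and the length limit are illustrative, not part of the project.

# Hypothetical evaluation loop using gold queries for decoding.
for interaction in dev_interactions:
    predictions = model.predict_with_gold_queries(
        interaction, max_generation_length=200, feed_gold_query=True)
    # Each element of `predictions` is the tuple returned by predict_turn for
    # one utterance (the decoder states sit at index 3, as used above).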
Example #4
    def predict_with_predicted_queries(self,
                                       interaction,
                                       max_generation_length,
                                       syntax_restrict=True):
        """ Predicts an interaction, using the predicted queries to get snippets."""
        assert self.params.discourse_level_lstm

        dy.renew_cg()

        predictions = []

        input_hidden_states = []
        input_sequences = []
        final_utterance_state = None

        discourse_state, discourse_lstm_states = self._initialize_discourse_states()

        interaction.start_interaction()
        while not interaction.done():
            utterance = interaction.next_utterance()

            available_snippets = utterance.snippets()
            previous_query = utterance.previous_query()

            input_sequence = utterance.input_sequence()
            final_utterance_state, utterance_states = self.utterance_encoder(
                input_sequence, lambda token: dy.concatenate(
                    [self.input_embedder(token), discourse_state]))

            input_hidden_states.extend(utterance_states)
            input_sequences.append(input_sequence)

            _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
                final_utterance_state[1][0], discourse_lstm_states)

            flat_sequence = []
            num_utterances_to_keep = min(self.params.maximum_utterances,
                                         len(input_sequences))
            for utt in input_sequences[-num_utterances_to_keep:]:
                flat_sequence.extend(utt)

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    input_hidden_states, input_sequences)

            snippets = None
            if self.params.use_snippets:
                snippets = self._encode_snippets(previous_query,
                                                 available_snippets)

            results = self.predict_turn(final_utterance_state,
                                        utterance_states,
                                        max_generation_length,
                                        input_sequence=flat_sequence,
                                        snippets=snippets)

            predicted_sequence = results[0]
            predictions.append(results)

            # Update things necessary for using predicted queries
            anonymized_sequence = utterance.remove_snippets(
                predicted_sequence)[:-1]
            flat_sequence = utterance.flatten_sequence(predicted_sequence)

            if not syntax_restrict or sql_util.executable(
                    flat_sequence,
                    username=self.params.database_username,
                    password=self.params.database_password,
                    timeout=self.params.database_timeout):
                utterance.set_predicted_query(
                    interaction.remove_snippets(predicted_sequence))
                interaction.add_utterance(
                    utterance,
                    anonymized_sequence,
                    previous_snippets=utterance.snippets())

            else:
                utterance.set_predicted_query(utterance.previous_query())
                interaction.add_utterance(
                    utterance,
                    utterance.previous_query(),
                    previous_snippets=utterance.snippets())

        return predictions
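For contrast with the gold-query variant above, a hedged sketch of fully predicted decoding; the interaction object and length limit are illustrative.

# Hypothetical call; snippets for each turn come from the model's own
# previous predictions rather than from gold queries.
predictions = model.predict_with_predicted_queries(
    interaction, max_generation_length=200, syntax_restrict=True)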
Example #5
    def train(self,
              interaction,
              max_generation_length,
              snippet_alignment_probability=1.):
        """ Trains the interaction-level model on a single interaction.

        Inputs:
            interaction (Interaction): The interaction to train on.
            learning_rate (float): Learning rate to use.
            snippet_keep_age (int): Age of oldest snippets to use.
            snippet_alignment_probability (float): The probability that a snippet will
                be used in constructing the gold sequence.
        """
        assert self.params.discourse_level_lstm

        dy.renew_cg()

        losses = []
        total_gold_tokens = 0

        input_hidden_states = []
        input_sequences = []
        final_utterance_state = None

        decoder_states = []

        discourse_state, discourse_lstm_states = self._initialize_discourse_states()

        # Indicates whether a new turn is starting
        new_turn = True

        for utterance_index, utterance in enumerate(
                interaction.gold_utterances()):
            if interaction.identifier in LIMITED_INTERACTIONS \
                    and utterance_index > LIMITED_INTERACTIONS[interaction.identifier]:
                break

            input_sequence = utterance.input_sequence()

            available_snippets = utterance.snippets()
            previous_query = utterance.previous_query()

            # Get the gold query: reconstruct it if the alignment probability
            # is less than one. Initialize gold_copy so it is defined on
            # either branch when passed to predict_turn below.
            gold_copy = None
            if snippet_alignment_probability < 1.:
                gold_query = sql_util.add_snippets_to_query(
                    available_snippets,
                    utterance.contained_entities(),
                    utterance.anonymized_gold_query(),
                    prob_align=snippet_alignment_probability) + [
                        vocab.EOS_TOK
                    ]
            else:
                gold_query = utterance.gold_query()
                if self.params.copy:
                    gold_copy = utterance.gold_copy()

            # Encode the utterance, and update the discourse-level states
            final_utterance_state, utterance_states = self.utterance_encoder(
                input_sequence,
                lambda token: dy.concatenate(
                    [self.input_embedder(token), discourse_state]),
                dropout_amount=self.dropout)

            input_hidden_states.extend(utterance_states)
            input_sequences.append(input_sequence)

            _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
                final_utterance_state[1][0], discourse_lstm_states,
                self.dropout)

            flat_sequence = []
            num_utterances_to_keep = min(self.params.maximum_utterances,
                                         len(input_sequences))
            for utt in input_sequences[-num_utterances_to_keep:]:
                flat_sequence.extend(utt)

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    input_hidden_states, input_sequences)

            snippets = None
            if self.params.use_snippets:
                if self.params.previous_decoder_snippet_encoding:
                    snippets = encode_snippets_with_states(
                        available_snippets, decoder_states)
                else:
                    snippets = self._encode_snippets(previous_query,
                                                     available_snippets)

            if len(gold_query) <= max_generation_length \
                    and len(previous_query) <= max_generation_length:
                prediction = self.predict_turn(final_utterance_state,
                                               utterance_states,
                                               max_generation_length,
                                               gold_query=gold_query,
                                               snippets=snippets,
                                               input_sequence=flat_sequence,
                                               feed_gold_tokens=True,
                                               training=True,
                                               first_utterance=new_turn,
                                               gold_copy=gold_copy)
                new_turn = False
                loss = prediction[1]
                decoder_states = prediction[3]
                total_gold_tokens += len(gold_query)
                losses.append(loss)
            else:
                # If snippets are encoded with previous decoder states, stop
                # here: this sequence was too long to run the decoder, so later
                # turns would have no decoder states to draw snippets from.
                if self.params.previous_decoder_snippet_encoding:
                    break
                continue

        if losses:
            average_loss = dy.esum(losses) / total_gold_tokens

            # Renormalize so the loss is scaled relative to the batch size.
            normalized_loss = average_loss
            if self.params.reweight_batch:
                normalized_loss = len(losses) * average_loss / \
                    float(self.params.batch_size)
            normalized_loss.forward()
            normalized_loss.backward()
            self.trainer.update()
            loss_scalar = normalized_loss.value()
        else:
            loss_scalar = 0.

        return loss_scalar
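A hedged sketch of an epoch-level loop built on `train`; `train_interactions`, the epoch count, and the length limit are assumptions, not project code.

# Hypothetical training driver; `train` returns a scalar loss per interaction.
for epoch in range(10):
    epoch_loss = 0.
    for interaction in train_interactions:
        epoch_loss += model.train(interaction,
                                  max_generation_length=200,
                                  snippet_alignment_probability=1.)
    print("epoch %d: total loss %f" % (epoch, epoch_loss))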
Example #6
    def __call__(self,
                 final_encoder_state,
                 encoder_states,
                 max_generation_length,
                 snippets=None,
                 gold_sequence=None,
                 input_sequence=None,
                 dropout_amount=0.):
        """ Generates a sequence. """
        index = 0

        context_vector_size = self.token_predictor.attention_module.value_size

        # Decoder states: just the initialized decoder.
        # Current input to the decoder: phi(start_token) concatenated with a
        # zero vector the size of the context vector.
        predictions = []
        sequence = []
        probability = 1.

        decoder_states = self._initialize_decoder_lstm(final_encoder_state)
        decoder_input = dy.concatenate(
            [self.start_token_embedding,
             dy.zeros((context_vector_size,))])

        continue_generating = True

        while continue_generating:
            if len(sequence) == 0 or sequence[-1] != EOS_TOK:
                _, decoder_state, decoder_states = du.forward_one_multilayer(
                    decoder_input, decoder_states, dropout_amount)
                prediction_input = PredictionInput(
                    decoder_state=decoder_state,
                    input_hidden_states=encoder_states,
                    snippets=snippets,
                    input_sequence=input_sequence)
                prediction = self.token_predictor(
                    prediction_input, dropout_amount=dropout_amount)

                predictions.append(prediction)

                if gold_sequence:
                    decoder_input = dy.concatenate([
                        self.output_embedder.bow_snippets(
                            gold_sequence[index], snippets),
                        prediction.attention_results.vector
                    ])
                    sequence.append(gold_sequence[index])

                    if index >= len(gold_sequence) - 1:
                        continue_generating = False
                else:
                    probabilities = np.transpose(
                        dy.softmax(prediction.scores).npvalue()).tolist()[0]
                    distribution_map = prediction.aligned_tokens

                    # Get new probabilities and a distribution map with
                    # duplicates consolidated
                    distribution_map, probabilities = flatten_distribution(
                        distribution_map, probabilities)

                    # Modify the probability distribution so that the UNK token can
                    # never be produced
                    probabilities[distribution_map.index(UNK_TOK)] = 0.
                    argmax_index = int(np.argmax(probabilities))

                    argmax_token = distribution_map[argmax_index]
                    sequence.append(argmax_token)

                    decoder_input = dy.concatenate([
                        self.output_embedder.bow_snippets(
                            argmax_token, snippets),
                        prediction.attention_results.vector
                    ])
                    probability *= probabilities[argmax_index]

                    continue_generating = False
                    if index < max_generation_length and argmax_token != EOS_TOK:
                        continue_generating = True

            index += 1

        return SQLPrediction(predictions, sequence, probability)
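A hedged inference sketch for this decoder; `decoder`, `final_encoder_state`, and `encoder_states` are assumed to come from the encoder code in the earlier examples.

# Hypothetical greedy decoding call (gold_sequence=None takes the argmax path).
result = decoder(final_encoder_state,
                 encoder_states,
                 max_generation_length=200)
print(result.sequence, result.probability)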
Example #7
    def __call__(self,
                 final_encoder_state,
                 encoder_states,
                 max_generation_length,
                 snippets=None,
                 gold_sequence=None,
                 input_sequence=None,
                 dropout_amount=0.,
                 controller=None,
                 first_utterance=True,
                 gold_copy=None):
        """ Generates a sequence. """
        index = 0

        context_vector_size = self.token_predictor.attention_module.value_size
        decoder_state_size = self.decoder_state_size

        state_stack = []
        pick_loss = None

        # Decoder states: just the initialized decoder.
        # Current input to decoder: phi(start_token) ; zeros the size of the
        # context vector
        predictions = []
        sequence = []
        probability = 1.

        decoder_states = self._initialize_decoder_lstm(final_encoder_state)
        decoder_input = dy.concatenate([self.start_token_embedding,
                                        dy.zeros((context_vector_size,)),
                                        dy.zeros((decoder_state_size,))])

        continue_generating = True
        if controller:
            controller.initialize()

        # TODO: at the start, predict the starting position from
        # LAST_DECODER_STATES and the current encoder states; the earlier
        # states can then be discarded.
        if (not first_utterance) and gold_copy:
            encoder_state = final_encoder_state[1][-1]
            intermediate = du.linear_transform(encoder_state, self.pick_pos_param)
            #print("intermediate: ", intermediate.dim()[0])
            #print("decoder: ", self.last_decoder_states[0].dim()[0])
            #print("length of last decoder states:", len(self.last_decoder_states))
            score = [intermediate * decoder_state for decoder_state in self.last_decoder_states]
            score = dy.concatenate(score)
            pick_loss = dy.pickneglogsoftmax(score, gold_copy[0])

            self.last_decoder_states = [final_encoder_state[1][-1]]
            start_pos = gold_copy[0]
            cnt = 0
            if start_pos == 0:
                index = 0
            else:
                for num, token in enumerate(gold_sequence):
                    if token == '<C>' or token == '<S>':
                        cnt += 1
                        if cnt == start_pos:
                            index = num
                            break
                    _, decoder_state, decoder_states = du.forward_one_multilayer(
                        decoder_input, decoder_states, dropout_amount)
                    prediction_input = PredictionInput(
                        decoder_state=decoder_state,
                        input_hidden_states=encoder_states,
                        snippets=snippets,
                        input_sequence=input_sequence)
                    prediction = self.token_predictor(
                        prediction_input,
                        dropout_amount=dropout_amount,
                        controller=controller)
                    if token == '<C>' or token == '<S>':
                        state_stack.append(decoder_state)
                        self.last_decoder_states.append(decoder_state)
                    if token == '<EOT>' and state_stack:
                        decoder_input = dy.concatenate([self.output_embedder(token), prediction.attention_results.vector,
                                                        state_stack.pop(-1)])
                    else:
                        decoder_input = dy.concatenate(
                            [self.output_embedder.bow_snippets(token,
                                                               snippets),
                             prediction.attention_results.vector,
                             dy.zeros((decoder_state_size,))])
                    controller.update(token)
        else:
            self.last_decoder_states = [final_encoder_state[1][-1]]

        while continue_generating:
            if len(sequence) == 0 or sequence[-1] != EOS_TOK:
                _, decoder_state, decoder_states = du.forward_one_multilayer(
                    decoder_input, decoder_states, dropout_amount)
                prediction_input = PredictionInput(decoder_state=decoder_state,
                                                   input_hidden_states=encoder_states,
                                                   snippets=snippets,
                                                   input_sequence=input_sequence)
                prediction = self.token_predictor(prediction_input,
                                                  dropout_amount=dropout_amount,
                                                  controller=controller)

                predictions.append(prediction)

                if gold_sequence:
                    token = gold_sequence[index]
                    if token == '<C>' or token == '<S>':
                        state_stack.append(decoder_state)
                        self.last_decoder_states.append(decoder_state)
                    if token == '<EOT>' and state_stack:
                        decoder_input = dy.concatenate([self.output_embedder(token), prediction.attention_results.vector,
                                                        state_stack.pop(-1)])
                    else:
                        decoder_input = dy.concatenate(
                            [self.output_embedder.bow_snippets(token, snippets),
                             prediction.attention_results.vector,
                             dy.zeros((decoder_state_size,))])
                    sequence.append(token)
                    if controller:
                        controller.update(gold_sequence[index])

                    if index >= len(gold_sequence) - 1:
                        continue_generating = False
                else:
                    probabilities = np.transpose(dy.softmax(
                        prediction.scores).npvalue()).tolist()[0]
                    distribution_map = prediction.aligned_tokens

                    # Get new probabilities and a distribution map with
                    # duplicates consolidated
                    distribution_map, probabilities = flatten_distribution(distribution_map,
                                                                           probabilities)

                    # Modify the probability distribution so that the UNK token can
                    # never be produced
                    probabilities[distribution_map.index(UNK_TOK)] = 0.
                    argmax_index = int(np.argmax(probabilities))

                    argmax_token = distribution_map[argmax_index]
                    if controller:
                        controller.update(argmax_token)
                    sequence.append(argmax_token)

                    decoder_input = dy.concatenate(
                        [self.output_embedder.bow_snippets(argmax_token, snippets),
                         prediction.attention_results.vector])
                    probability *= probabilities[argmax_index]

                    continue_generating = False
                    if index < max_generation_length and argmax_token != EOS_TOK:
                        continue_generating = True

            index += 1

        return SQLPrediction(predictions,
                             sequence,
                             probability), pick_loss
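A hedged training-time sketch for this copy-aware variant; `decoder`, `controller`, `gold_query`, and `gold_copy` are assumed to come from the surrounding training code (compare Example #5).

# Hypothetical call; pick_loss is returned alongside the prediction.
result, pick_loss = decoder(final_encoder_state,
                            encoder_states,
                            max_generation_length=200,
                            gold_sequence=gold_query,
                            controller=controller,
                            first_utterance=False,
                            gold_copy=gold_copy)
# pick_loss is non-None only when a copied start position was scored; it would
# be combined with the token-level loss before backpropagation.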