def _encode_with_discourse_lstm(self, utterances):
    """Encode utterances in order while threading a discourse-level LSTM.

    Every token embedding is concatenated with the current discourse state
    before it reaches the utterance encoder. After each utterance the
    discourse LSTM is stepped once with that utterance's final state.

    Inputs:
        utterances (list of list of str): Utterances.

    Returns:
        The final state returned by the utterance encoder for the last
        utterance (None when `utterances` is empty), and the list of the
        encoder states of all utterances in order.
    """
    all_states = []
    discourse_vec, lstm_states = self._initialize_discourse_states()
    last_final_state = None

    def embed_with_discourse(token):
        # Reads `discourse_vec` at call time, so each utterance is embedded
        # with the discourse state produced by the previous one.
        return dy.concatenate([self.input_embedder(token), discourse_vec])

    for tokens in utterances:
        last_final_state, states = self.utterance_encoder(
            tokens,
            embed_with_discourse,
            dropout_amount=self.dropout)
        all_states.extend(states)

        # NOTE(review): the sibling methods step the discourse LSTM with
        # `final_state[1][0]`; here the whole final state is passed --
        # confirm this is intended.
        step_result = du.forward_one_multilayer(
            last_final_state, lstm_states, self.dropout)
        _, discourse_vec, lstm_states = step_result

    return last_final_state, all_states
def interactive_prediction(self, anonymizer):
    """Interactive prediction.

    Repeatedly takes an utterance, encodes it together with the running
    discourse state, predicts a sequence with `predict_turn`, and prints it.
    The loop stops once the lower-cased utterance is in END_OF_INTERACTION.

    Inputs:
        anonymizer (Anonymizer): Anonymizer to use for user's input.
    """
    dy.renew_cg()

    # Neither the snippet bank nor the previous query is updated inside the
    # loop yet (see the trailing comments at the end of the loop body).
    snippet_bank = []
    anonymization_dictionary = {}
    previous_query = []

    input_hidden_states = []
    input_sequences = []

    discourse_state, discourse_lstm_states = self._initialize_discourse_states()

    # The first utterance is hard-coded; later ones are read from stdin.
    utterance = "show me flights from new york to boston"  # input("> ")

    while utterance.lower() not in END_OF_INTERACTION:
        # First, need to normalize the utterance and get an anonymization
        # dictionary.
        tokenized_sequence = tokenizers.nl_tokenize(utterance)
        available_snippets = [
            snippet for snippet in snippet_bank if snippet.index <= 1
        ]

        sequence_to_use = tokenized_sequence
        # TODO: implement date normalization
        if self.params.anonymize:
            sequence_to_use = anonymizer.anonymize(
                tokenized_sequence,
                anonymization_dictionary,
                ANON_INPUT_KEY,
                add_new_anon_toks=True)

        # Now we encode the sequence
        final_utterance_state, utterance_states = self.utterance_encoder(
            sequence_to_use,
            lambda token: dy.concatenate(
                [self.input_embedder(token), discourse_state]))
        input_hidden_states.extend(utterance_states)
        input_sequences.append(sequence_to_use)

        # Now update the discourse state
        _, discourse_state, discourse_lstm_states = du.forward_one_multilayer(
            final_utterance_state[1][0], discourse_lstm_states)

        # Flatten the most recent utterances into a single token sequence.
        flat_sequence = []
        num_utterances_to_keep = min(self.params.maximum_utterances,
                                     len(input_sequences))
        for utt in input_sequences[-num_utterances_to_keep:]:
            flat_sequence.extend(utt)

        # Add positional embeddings
        if self.params.state_positional_embeddings:
            utterance_states, flat_sequence = self._add_positional_embeddings(
                input_hidden_states, input_sequences)

        # Encode the snippets. `snippets` must be bound even when snippets
        # are disabled, because it is always passed to `predict_turn` below;
        # the other prediction methods default it to None in the same way.
        snippets = None
        if self.params.use_snippets:
            snippets = self._encode_snippets(previous_query,
                                             available_snippets)

        # Predict a result
        results = self.predict_turn(final_utterance_state,
                                    utterance_states,
                                    self.params.eval_maximum_sql_length,
                                    input_sequence=flat_sequence,
                                    snippets=snippets)

        # Get the sequence, and show the de-anonymized and flattened
        # versions
        predicted_sequence = results[0]
        print(predicted_sequence)

        # Execute the query and show the results

        # Update the available snippets, etc.

        utterance = input("> ")
def predict_with_gold_queries(self, interaction, max_generation_length,
                              feed_gold_query=False):
    """Predict a SQL query for each gold utterance of an interaction.

    Inputs:
        interaction (Interaction): Interaction to predict for.
        max_generation_length (int): Forwarded to `predict_turn` as its
            generation length argument.
        feed_gold_query (bool): Whether or not to feed the gold token to the
            generation step.

    Returns:
        list: The `predict_turn` result of every gold utterance, in order.
    """
    assert self.params.discourse_level_lstm

    dy.renew_cg()

    turn_predictions = []
    encoder_history = []
    sequence_history = []
    prev_decoder_states = []

    discourse_vec, lstm_states = self._initialize_discourse_states()

    def embed_with_discourse(token):
        # Looks up `discourse_vec` at call time, i.e. the state left by the
        # previous utterance.
        return dy.concatenate([self.input_embedder(token), discourse_vec])

    for utterance in interaction.gold_utterances():
        tokens = utterance.input_sequence()
        candidate_snippets = utterance.snippets()
        prior_query = utterance.previous_query()

        # Encode the utterance, and update the discourse-level states
        final_state, turn_states = self.utterance_encoder(
            tokens, embed_with_discourse, dropout_amount=self.dropout)
        encoder_history.extend(turn_states)
        sequence_history.append(tokens)

        _, discourse_vec, lstm_states = du.forward_one_multilayer(
            final_state[1][0], lstm_states, self.dropout)

        # Flatten the most recent utterances into one token list.
        keep = min(self.params.maximum_utterances, len(sequence_history))
        flat_tokens = [
            tok for past in sequence_history[-keep:] for tok in past
        ]

        if self.params.state_positional_embeddings:
            turn_states, flat_tokens = self._add_positional_embeddings(
                encoder_history, sequence_history)

        encoded_snippets = None
        if self.params.use_snippets:
            if self.params.previous_decoder_snippet_encoding:
                encoded_snippets = encode_snippets_with_states(
                    candidate_snippets, prev_decoder_states)
            else:
                encoded_snippets = self._encode_snippets(
                    prior_query, candidate_snippets)

        turn_result = self.predict_turn(
            final_state,
            turn_states,
            max_generation_length,
            gold_query=utterance.gold_query(),
            snippets=encoded_snippets,
            input_sequence=flat_tokens,
            feed_gold_tokens=feed_gold_query)

        # The fourth element of the result is carried into the next turn's
        # snippet encoding.
        prev_decoder_states = turn_result[3]
        turn_predictions.append(turn_result)

    return turn_predictions
def predict_with_predicted_queries(self, interaction, max_generation_length,
                                   syntax_restrict=True):
    """Predict an interaction, using the predicted queries to get snippets.

    Inputs:
        interaction (Interaction): Interaction to predict for; it is
            advanced with `next_utterance` / `add_utterance`.
        max_generation_length (int): Forwarded to `predict_turn` as its
            generation length argument.
        syntax_restrict (bool): When true, a predicted query is only kept if
            `sql_util.executable` accepts its flattened form; otherwise the
            utterance's previous query is recorded instead.

    Returns:
        list: The `predict_turn` result of every utterance, in order.
    """
    assert self.params.discourse_level_lstm

    dy.renew_cg()

    turn_results = []
    encoder_history = []
    sequence_history = []

    discourse_vec, lstm_states = self._initialize_discourse_states()

    def embed_with_discourse(token):
        # Reads `discourse_vec` at call time (state after the previous turn).
        return dy.concatenate([self.input_embedder(token), discourse_vec])

    interaction.start_interaction()
    while not interaction.done():
        utterance = interaction.next_utterance()

        candidate_snippets = utterance.snippets()
        prior_query = utterance.previous_query()
        tokens = utterance.input_sequence()

        final_state, turn_states = self.utterance_encoder(
            tokens, embed_with_discourse)
        encoder_history.extend(turn_states)
        sequence_history.append(tokens)

        _, discourse_vec, lstm_states = du.forward_one_multilayer(
            final_state[1][0], lstm_states)

        # Flatten the most recent utterances into one token list.
        keep = min(self.params.maximum_utterances, len(sequence_history))
        flat_tokens = [
            tok for past in sequence_history[-keep:] for tok in past
        ]

        if self.params.state_positional_embeddings:
            turn_states, flat_tokens = self._add_positional_embeddings(
                encoder_history, sequence_history)

        encoded_snippets = None
        if self.params.use_snippets:
            encoded_snippets = self._encode_snippets(prior_query,
                                                     candidate_snippets)

        result = self.predict_turn(final_state,
                                   turn_states,
                                   max_generation_length,
                                   input_sequence=flat_tokens,
                                   snippets=encoded_snippets)
        predicted = result[0]
        turn_results.append(result)

        # Update things necessary for using predicted queries
        anonymized = utterance.remove_snippets(predicted)[:-1]
        flattened_query = utterance.flatten_sequence(predicted)

        # The executability check is skipped entirely when syntax
        # restriction is turned off.
        accepted = (not syntax_restrict) or sql_util.executable(
            flattened_query,
            username=self.params.database_username,
            password=self.params.database_password,
            timeout=self.params.database_timeout)

        if accepted:
            utterance.set_predicted_query(
                interaction.remove_snippets(predicted))
            interaction.add_utterance(
                utterance,
                anonymized,
                previous_snippets=utterance.snippets())
        else:
            # Fall back to the previous query for this turn.
            utterance.set_predicted_query(utterance.previous_query())
            interaction.add_utterance(
                utterance,
                utterance.previous_query(),
                previous_snippets=utterance.snippets())

    return turn_results
def train(self, interaction, max_generation_length,
          snippet_alignment_probability=1.):
    """Train the interaction-level model on a single interaction.

    Inputs:
        interaction (Interaction): The interaction to train on.
        max_generation_length (int): Utterances whose gold query or previous
            query is longer than this are not decoded; it is also forwarded
            to `predict_turn`.
        snippet_alignment_probability (float): The probability that a snippet
            will be used in constructing the gold sequence.

    Returns:
        The value of the normalized loss, or 0. when no utterance produced
        a loss.
    """
    assert self.params.discourse_level_lstm

    dy.renew_cg()

    turn_losses = []
    gold_token_count = 0
    encoder_history = []
    sequence_history = []
    prev_decoder_states = []

    discourse_vec, lstm_states = self._initialize_discourse_states()

    # Indicates whether a new turn is starting; passed to `predict_turn` as
    # `first_utterance` and cleared after the first decoded utterance.
    new_turn = True

    def embed_with_discourse(token):
        # Reads `discourse_vec` at call time (state after the previous turn).
        return dy.concatenate([self.input_embedder(token), discourse_vec])

    for turn_idx, utterance in enumerate(interaction.gold_utterances()):
        if interaction.identifier in LIMITED_INTERACTIONS:
            if turn_idx > LIMITED_INTERACTIONS[interaction.identifier]:
                break

        tokens = utterance.input_sequence()
        candidate_snippets = utterance.snippets()
        prior_query = utterance.previous_query()

        # Get the gold query: reconstruct if the alignment probability
        # is less than one
        if snippet_alignment_probability < 1.:
            realigned = sql_util.add_snippets_to_query(
                candidate_snippets,
                utterance.contained_entities(),
                utterance.anonymized_gold_query(),
                prob_align=snippet_alignment_probability)
            gold_query = realigned + [vocab.EOS_TOK]
        else:
            gold_query = utterance.gold_query()

        gold_copy = utterance.gold_copy() if self.params.copy else None

        # Encode the utterance, and update the discourse-level states
        final_state, turn_states = self.utterance_encoder(
            tokens, embed_with_discourse, dropout_amount=self.dropout)
        encoder_history.extend(turn_states)
        sequence_history.append(tokens)

        _, discourse_vec, lstm_states = du.forward_one_multilayer(
            final_state[1][0], lstm_states, self.dropout)

        # Flatten the most recent utterances into one token list.
        keep = min(self.params.maximum_utterances, len(sequence_history))
        flat_tokens = [
            tok for past in sequence_history[-keep:] for tok in past
        ]

        if self.params.state_positional_embeddings:
            turn_states, flat_tokens = self._add_positional_embeddings(
                encoder_history, sequence_history)

        encoded_snippets = None
        if self.params.use_snippets:
            if self.params.previous_decoder_snippet_encoding:
                encoded_snippets = encode_snippets_with_states(
                    candidate_snippets, prev_decoder_states)
            else:
                encoded_snippets = self._encode_snippets(
                    prior_query, candidate_snippets)

        if len(gold_query) > max_generation_length \
                or len(prior_query) > max_generation_length:
            # Break if previous decoder snippet encoding -- because the
            # previous sequence was too long to run the decoder.
            if self.params.previous_decoder_snippet_encoding:
                break
            continue

        print(utterance_index := turn_idx) if False else print(turn_idx)
        turn_result = self.predict_turn(final_state,
                                        turn_states,
                                        max_generation_length,
                                        gold_query=gold_query,
                                        snippets=encoded_snippets,
                                        input_sequence=flat_tokens,
                                        feed_gold_tokens=True,
                                        training=True,
                                        first_utterance=new_turn,
                                        gold_copy=gold_copy)
        new_turn = False

        prev_decoder_states = turn_result[3]
        gold_token_count += len(gold_query)
        turn_losses.append(turn_result[1])

    if not turn_losses:
        return 0.

    average_loss = dy.esum(turn_losses) / gold_token_count

    # Renormalize so the effect is normalized by the batch size.
    normalized_loss = average_loss
    if self.params.reweight_batch:
        normalized_loss = len(turn_losses) * average_loss / \
            float(self.params.batch_size)

    normalized_loss.forward()
    normalized_loss.backward()
    self.trainer.update()

    return normalized_loss.value()
def __call__(self, final_encoder_state, encoder_states, max_generation_length,
             snippets=None, gold_sequence=None, input_sequence=None,
             dropout_amount=0.):
    """Generate a sequence, either following a gold sequence or greedily.

    With a non-empty `gold_sequence` the gold token is fed back at every
    step and decoding stops at the last gold token. Otherwise the
    highest-probability token (UNK excluded) is fed back until EOS is
    produced or the step index reaches `max_generation_length`.

    Returns:
        SQLPrediction built from the per-step predictions, the token
        sequence, and the product of the chosen tokens' probabilities
        (left at 1. when following a gold sequence).
    """
    position = 0
    attention_size = self.token_predictor.attention_module.value_size

    step_predictions = []
    generated = []
    sequence_probability = 1.

    # Decoder states: just the initialized decoder.
    # Current input to decoder: phi(start_token) ; zeros the size of the
    # context vector
    lstm_states = self._initialize_decoder_lstm(final_encoder_state)
    next_input = dy.concatenate(
        [self.start_token_embedding, dy.zeroes((attention_size, ))])

    keep_going = True
    while keep_going:
        # NOTE(review): nothing in this loop clears `keep_going` when this
        # guard is false -- presumably EOS only ever appears last; confirm.
        if not generated or generated[-1] != EOS_TOK:
            _, top_state, lstm_states = du.forward_one_multilayer(
                next_input, lstm_states, dropout_amount)

            step = self.token_predictor(
                PredictionInput(decoder_state=top_state,
                                input_hidden_states=encoder_states,
                                snippets=snippets,
                                input_sequence=input_sequence),
                dropout_amount=dropout_amount)
            step_predictions.append(step)

            if gold_sequence:
                next_input = dy.concatenate([
                    self.output_embedder.bow_snippets(
                        gold_sequence[position], snippets),
                    step.attention_results.vector
                ])
                generated.append(gold_sequence[position])

                if position >= len(gold_sequence) - 1:
                    keep_going = False
            else:
                score_values = dy.softmax(step.scores).npvalue()
                token_probs = np.transpose(score_values).tolist()[0]

                # Get a new probabilities and distribution_map consolidating
                # duplicates
                tokens, token_probs = flatten_distribution(
                    step.aligned_tokens, token_probs)

                # Modify the probability distribution so that the UNK token
                # can never be produced
                token_probs[tokens.index(UNK_TOK)] = 0.

                best = int(np.argmax(token_probs))
                best_token = tokens[best]
                generated.append(best_token)

                next_input = dy.concatenate([
                    self.output_embedder.bow_snippets(best_token, snippets),
                    step.attention_results.vector
                ])
                sequence_probability *= token_probs[best]

                keep_going = (position < max_generation_length
                              and best_token != EOS_TOK)

        position += 1

    return SQLPrediction(step_predictions, generated, sequence_probability)
def __call__(self, final_encoder_state, encoder_states, max_generation_length,
             snippets=None, gold_sequence=None, input_sequence=None,
             dropout_amount=0., controller=None, first_utterance=True,
             gold_copy=None):
    """Generates a sequence, optionally starting after a replayed gold prefix.

    The decoder input at every step is the concatenation of a token
    embedding, the attention vector of the previous prediction, and either
    a decoder state popped from an internal stack (on '<EOT>') or zeros of
    size `self.decoder_state_size`.

    When `first_utterance` is false and `gold_copy` is non-empty, a pick loss
    over `self.last_decoder_states` is computed against `gold_copy[0]`, and
    the gold tokens before the `gold_copy[0]`-th '<C>'/'<S>' marker are fed
    through the decoder without being recorded in the returned predictions.

    Inputs:
        final_encoder_state: Used to initialize the decoder LSTM; its
            `[1][-1]` entry seeds `self.last_decoder_states`.
        encoder_states: Passed to the token predictor as input hidden states.
        max_generation_length (int): Bound on the step index in free
            generation.
        snippets: Passed to the token predictor and output embedder.
        gold_sequence (list): If non-empty, its tokens are fed back instead
            of the predicted ones.
        input_sequence: Passed to the token predictor.
        dropout_amount (float): Dropout for the decoder LSTM and predictor.
        controller: Optional object that is initialized once and updated
            with every consumed token; also forwarded to the token predictor.
        first_utterance (bool): Disables the pick loss / prefix replay.
        gold_copy: Its first element is the gold start position.

    Returns:
        A pair of SQLPrediction(predictions, sequence, probability) and the
        pick loss (None when no pick loss was computed).
    """
    index = 0
    context_vector_size = self.token_predictor.attention_module.value_size
    decoder_state_size = self.decoder_state_size
    state_stack = []
    pick_loss = None

    # Decoder states: just the initialized decoder.
    # Current input to decoder: phi(start_token) ; zeros the size of the
    # context vector
    predictions = []
    sequence = []
    probability = 1.

    decoder_states = self._initialize_decoder_lstm(final_encoder_state)
    decoder_input = dy.concatenate([self.start_token_embedding,
                                    dy.zeroes((context_vector_size,)),
                                    dy.zeros((decoder_state_size,))])
    continue_generating = True
    if controller:
        controller.initialize()

    def advance_decoder_input(token, decoder_state, prediction):
        # Builds the next decoder input for a consumed token. '<C>'/'<S>'
        # remember the current decoder state; '<EOT>' feeds the most
        # recently remembered state back in; everything else gets zeros in
        # that slot so the input size stays constant.
        if token == '<C>' or token == '<S>':
            state_stack.append(decoder_state)
            self.last_decoder_states.append(decoder_state)
        if token == '<EOT>' and state_stack:
            return dy.concatenate([self.output_embedder(token),
                                   prediction.attention_results.vector,
                                   state_stack.pop(-1)])
        return dy.concatenate(
            [self.output_embedder.bow_snippets(token, snippets),
             prediction.attention_results.vector,
             dy.zeros((decoder_state_size,))])

    # TODO: at the start, predict the starting position from
    # LAST_DECODER_STATES and the current ENCODER STATES; the previous
    # states can then be discarded.
    if (not first_utterance) and gold_copy:
        encoder_state = final_encoder_state[1][-1]
        intermediate = du.linear_transform(encoder_state, self.pick_pos_param)
        score = dy.concatenate(
            [intermediate * previous_state
             for previous_state in self.last_decoder_states])
        pick_loss = dy.pickneglogsoftmax(score, gold_copy[0])

        self.last_decoder_states = [final_encoder_state[1][-1]]
        start_pos = gold_copy[0]
        cnt = 0
        if start_pos == 0:
            index = 0
        else:
            # Replay the gold prefix through the decoder until the
            # start_pos-th '<C>'/'<S>' marker is reached.
            for num, token in enumerate(gold_sequence):
                if token == '<C>' or token == '<S>':
                    cnt += 1
                if cnt == start_pos:
                    index = num
                    break

                _, decoder_state, decoder_states = du.forward_one_multilayer(
                    decoder_input, decoder_states, dropout_amount)
                prediction_input = PredictionInput(
                    decoder_state=decoder_state,
                    input_hidden_states=encoder_states,
                    snippets=snippets,
                    input_sequence=input_sequence)
                prediction = self.token_predictor(
                    prediction_input,
                    dropout_amount=dropout_amount,
                    controller=controller)

                decoder_input = advance_decoder_input(
                    token, decoder_state, prediction)
                # `controller` defaults to None, so guard it here as
                # everywhere else in this method.
                if controller:
                    controller.update(token)
    else:
        self.last_decoder_states = [final_encoder_state[1][-1]]

    while continue_generating:
        if len(sequence) == 0 or sequence[-1] != EOS_TOK:
            _, decoder_state, decoder_states = du.forward_one_multilayer(
                decoder_input, decoder_states, dropout_amount)
            if gold_sequence and controller:
                # Result is unused; the lookup is kept (guarded against a
                # missing controller) in case it validates the gold token --
                # TODO confirm and remove.
                truth_label = controller.vocab.token_to_label(
                    gold_sequence[index])
            prediction_input = PredictionInput(
                decoder_state=decoder_state,
                input_hidden_states=encoder_states,
                snippets=snippets,
                input_sequence=input_sequence)
            prediction = self.token_predictor(prediction_input,
                                              dropout_amount=dropout_amount,
                                              controller=controller)
            predictions.append(prediction)

            if gold_sequence:
                token = gold_sequence[index]
                decoder_input = advance_decoder_input(
                    token, decoder_state, prediction)
                sequence.append(token)
                if controller:
                    controller.update(gold_sequence[index])

                if index >= len(gold_sequence) - 1:
                    continue_generating = False
            else:
                probabilities = np.transpose(dy.softmax(
                    prediction.scores).npvalue()).tolist()[0]
                distribution_map = prediction.aligned_tokens

                # Get a new probabilities and distribution_map consolidating
                # duplicates
                distribution_map, probabilities = flatten_distribution(
                    distribution_map, probabilities)

                # Modify the probability distribution so that the UNK token
                # can never be produced
                probabilities[distribution_map.index(UNK_TOK)] = 0.
                argmax_index = int(np.argmax(probabilities))

                argmax_token = distribution_map[argmax_index]
                if controller:
                    controller.update(argmax_token)
                sequence.append(argmax_token)

                # Use the same three-part input as the initial step and the
                # gold path; a two-part input here would change the decoder
                # input size after the first generated token.
                decoder_input = advance_decoder_input(
                    argmax_token, decoder_state, prediction)
                probability *= probabilities[argmax_index]

                continue_generating = False
                if index < max_generation_length and argmax_token != EOS_TOK:
                    continue_generating = True

        index += 1

    return SQLPrediction(predictions, sequence, probability), pick_loss