def bow_snippets(token, snippets, output_embedder, input_schema):
    """Embed a snippet token as the mean (bag-of-words) of the embeddings
    of the tokens in its underlying sequence.

    When an input schema is supplied, tokens missing from the output
    vocabulary fall back to the schema's column-name embedder.
    """
    assert snippet_handler.is_snippet(token) and snippets

    # Find the sequence belonging to the first snippet with this name.
    matched_sequence = next(
        (snippet.sequence for snippet in snippets if snippet.name == token), [])
    assert matched_sequence

    if input_schema:
        embedded_tokens = []
        for seq_token in matched_sequence:
            assert output_embedder.in_vocabulary(
                seq_token) or input_schema.in_vocabulary(seq_token, surface_form=True)
            if output_embedder.in_vocabulary(seq_token):
                embedded_tokens.append(output_embedder(seq_token))
            else:
                embedded_tokens.append(
                    input_schema.column_name_embedder(seq_token, surface_form=True))
    else:
        embedded_tokens = [output_embedder(seq_token)
                           for seq_token in matched_sequence]

    stacked = paddle.stack(embedded_tokens, axis=0)  # len(sequence) x emb_size
    return paddle.mean(stacked, axis=0)  # emb_size
def forward(self, token):
    """Embed a single flat token; out-of-vocabulary tokens map to the
    unknown-token row of the embedding matrix."""
    assert isinstance(token, int) or not snippet_handler.is_snippet(
        token
    ), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets"

    # Both branches of the original differ only in which row is looked up.
    if self.in_vocabulary(token):
        row_id = self.vocab_token_lookup(token)
    else:
        row_id = self.unknown_token_id
    index = paddle.to_tensor(row_id, 'int64')
    return self.token_embedding_matrix(index).squeeze()
def get_output_token_embedding(self, output_token, input_schema, snippets):
    """Return the embedding for one output token.

    Dispatches, in order: snippet tokens (bag-of-words over the snippet
    sequence), output-vocabulary tokens, and — when a schema is given —
    schema column names via the schema's surface-form embedder.
    """
    if self.params.use_snippets and snippet_handler.is_snippet(output_token):
        return embedder.bow_snippets(output_token, snippets,
                                     self.output_embedder, input_schema)

    if not input_schema:
        return self.output_embedder(output_token)

    # With a schema present, every token must be coverable somewhere.
    assert self.output_embedder.in_vocabulary(output_token) or \
        input_schema.in_vocabulary(output_token, surface_form=True)
    if self.output_embedder.in_vocabulary(output_token):
        return self.output_embedder(output_token)
    return input_schema.column_name_embedder(output_token, surface_form=True)
def forward(self, token):
    """Embed a single flat (non-snippet) token.

    Lookup order: the regular vocabulary, then (if an anonymizer is
    configured) the entity embedding matrix for anonymized tokens, and
    finally the unknown-token row.

    Args:
        token: the token to embed; snippet tokens are rejected.

    Returns:
        A 1-D embedding tensor (the squeezed single-row lookup).
    """
    assert isinstance(token, int) or not snippet_handler.is_snippet(token), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets"

    def _lookup(matrix, row_id):
        # Single-row embedding lookup. The original repeated this
        # index-tensor / .cuda() / squeeze dance three times; keep the
        # exact device check it used.
        # NOTE(review): the device test reads token_embedding_matrix even
        # when indexing entity_embedding_matrix — presumably both live on
        # the same device; confirm.
        index_list = torch.LongTensor([row_id])
        if self.token_embedding_matrix.weight.is_cuda:
            index_list = index_list.cuda()
        return matrix(index_list).squeeze()

    if self.in_vocabulary(token):
        return _lookup(self.token_embedding_matrix, self.vocab_token_lookup(token))
    if self.anonymizer and self.anonymizer.is_anon_tok(token):
        return _lookup(self.entity_embedding_matrix, self.anonymizer.get_anon_id(token))
    return _lookup(self.token_embedding_matrix, self.unknown_token_id)
def predict_turn(self, utterance_final_state, input_hidden_states, schema_states, max_generation_length, gold_query=None, snippets=None, input_sequence=None, previous_queries=None, previous_query_states=None, input_schema=None, feed_gold_tokens=False, training=False):
    """Gets a prediction for a single turn -- calls the decoder and, when
    gold tokens are fed, computes the loss and token accuracy.

    TODO: this can probably be split into two methods, one that just
    predicts and another that computes the loss.

    Args:
        utterance_final_state: final encoder state of the current utterance.
        input_hidden_states: per-token encoder states of the input sequence.
        schema_states: per-column encoder states of the database schema.
        max_generation_length: cap on the number of decoding steps.
        gold_query: gold output token sequence (used when feed_gold_tokens).
        snippets: available query snippets (used to expand decoder states).
        input_sequence: raw input tokens, passed through to the decoder.
        previous_queries: earlier queries for the copy mechanism.
        previous_query_states: decoder states of those earlier queries.
        input_schema: schema object passed through to the decoder.
        feed_gold_tokens: teacher-force the decoder and compute the loss.
        training: when True, skip the argmax decode that is only needed to
            report token accuracy.

    Returns:
        (predicted_sequence, loss, token_accuracy, decoder_states,
        decoder_results).
    """
    predicted_sequence = []
    fed_sequence = []
    loss = None
    token_accuracy = 0.

    if self.params.use_encoder_attention:
        # Cross-attention in both directions between the utterance encoding
        # and the schema encoding.
        schema_attention = self.utterance2schema_attention_module(torch.stack(schema_states, dim=0), input_hidden_states).vector  # input_value_size x len(schema)
        utterance_attention = self.schema2utterance_attention_module(torch.stack(input_hidden_states, dim=0), schema_states).vector  # schema_value_size x len(input)

        # A length-1 schema/input collapses the attention output to 1-D;
        # restore the trailing dimension so the concatenations below line up.
        if schema_attention.dim() == 1:
            schema_attention = schema_attention.unsqueeze(1)
        if utterance_attention.dim() == 1:
            utterance_attention = utterance_attention.unsqueeze(1)

        # Append the attention context onto each state vector, then split
        # back into per-position lists of squeezed vectors.
        new_schema_states = torch.cat([torch.stack(schema_states, dim=1), schema_attention], dim=0)  # (input_value_size+schema_value_size) x len(schema)
        schema_states = list(torch.split(new_schema_states, split_size_or_sections=1, dim=1))
        schema_states = [schema_state.squeeze() for schema_state in schema_states]

        new_input_hidden_states = torch.cat([torch.stack(input_hidden_states, dim=1), utterance_attention], dim=0)  # (input_value_size+schema_value_size) x len(input)
        input_hidden_states = list(torch.split(new_input_hidden_states, split_size_or_sections=1, dim=1))
        input_hidden_states = [input_hidden_state.squeeze() for input_hidden_state in input_hidden_states]

        # bi-lstm over schema_states and input_hidden_states (embedder is an identify function)
        if self.params.use_schema_encoder_2:
            final_schema_state, schema_states = self.schema_encoder_2(schema_states, lambda x: x, dropout_amount=self.dropout)
            final_utterance_state, input_hidden_states = self.utterance_encoder_2(input_hidden_states, lambda x: x, dropout_amount=self.dropout)

    if feed_gold_tokens:
        decoder_results = self.decoder(utterance_final_state, input_hidden_states, schema_states, max_generation_length, gold_sequence=gold_query, input_sequence=input_sequence, previous_queries=previous_queries, previous_query_states=previous_query_states, input_schema=input_schema, snippets=snippets, dropout_amount=self.dropout)

        all_scores = []
        all_alignments = []
        for prediction in decoder_results.predictions:
            scores = F.softmax(prediction.scores, dim=0)
            alignments = prediction.aligned_tokens
            # With the copy switch active, mix the generation distribution
            # with the copy-from-previous-query distribution, weighted by
            # the predicted switch value.
            if self.params.use_previous_query and self.params.use_copy_switch and len(previous_queries) > 0:
                query_scores = F.softmax(prediction.query_scores, dim=0)
                copy_switch = prediction.copy_switch
                scores = torch.cat([scores * (1 - copy_switch), query_scores * copy_switch], dim=0)
                alignments = alignments + prediction.query_tokens
            all_scores.append(scores)
            all_alignments.append(alignments)

        # Compute the loss
        gold_sequence = gold_query
        loss = torch_utils.compute_loss(gold_sequence, all_scores, all_alignments, get_token_indices)
        if not training:
            predicted_sequence = torch_utils.get_seq_from_scores(all_scores, all_alignments)
            token_accuracy = torch_utils.per_token_accuracy(gold_sequence, predicted_sequence)
        fed_sequence = gold_sequence
    else:
        decoder_results = self.decoder(utterance_final_state, input_hidden_states, schema_states, max_generation_length, input_sequence=input_sequence, previous_queries=previous_queries, previous_query_states=previous_query_states, input_schema=input_schema, snippets=snippets, dropout_amount=self.dropout)
        predicted_sequence = decoder_results.sequence
        fed_sequence = predicted_sequence

    decoder_states = [pred.decoder_state for pred in decoder_results.predictions]

    # fed_sequence contains EOS, which we don't need when encoding snippets.
    # also ignore the first state, as it contains the BEG encoding.
    # NOTE(review): the loop appends to decoder_states while zipping over a
    # snapshot slice of it (the slice is evaluated once, so iteration is
    # safe), leaving the original states followed by the expanded states --
    # confirm this doubling is what downstream consumers expect.
    for token, state in zip(fed_sequence[:-1], decoder_states[1:]):
        if snippet_handler.is_snippet(token):
            # A snippet token stands in for several output tokens; repeat
            # its decoder state once per underlying token.
            snippet_length = 0
            for snippet in snippets:
                if snippet.name == token:
                    snippet_length = len(snippet.sequence)
                    break
            assert snippet_length > 0
            decoder_states.extend([state for _ in range(snippet_length)])
        else:
            decoder_states.append(state)

    return (predicted_sequence, loss, token_accuracy, decoder_states, decoder_results)
def predict_turn(self, utterance_final_state, input_hidden_states, schema_states, max_generation_length, gold_query=None, snippets=None, input_sequence=None, input_schema=None, feed_gold_tokens=False, training=False):
    """Gets a prediction for a single turn -- calls the decoder and, when
    gold tokens are fed, computes the loss and token accuracy.

    TODO: this can probably be split into two methods, one that just
    predicts and another that computes the loss.

    Args:
        utterance_final_state: final encoder state of the current utterance.
        input_hidden_states: per-token encoder states of the input sequence.
        schema_states: per-column encoder states of the database schema.
        max_generation_length: cap on the number of decoding steps.
        gold_query: gold output token sequence (used when feed_gold_tokens).
        snippets: available query snippets (used to expand decoder states).
        input_sequence: raw input tokens, passed through to the decoder.
        input_schema: schema object passed through to the decoder.
        feed_gold_tokens: teacher-force the decoder and compute the loss.
        training: when True, skip the argmax decode that is only needed to
            report token accuracy.

    Returns:
        (predicted_sequence, loss, token_accuracy, decoder_states,
        decoder_results).
    """
    predicted_sequence = []
    fed_sequence = []
    loss = None
    token_accuracy = 0.

    if feed_gold_tokens:
        # Teacher forcing: decode against the gold sequence.
        decoder_results = self.decoder(utterance_final_state, input_hidden_states, schema_states, max_generation_length, gold_sequence=gold_query, input_sequence=input_sequence, input_schema=input_schema, snippets=snippets, dropout_amount=self.dropout)

        all_scores = [step.scores for step in decoder_results.predictions]
        all_alignments = [
            step.aligned_tokens for step in decoder_results.predictions
        ]

        # Compute the loss
        gold_sequence = gold_query
        loss = torch_utils.compute_loss(gold_sequence, all_scores, all_alignments, get_token_indices)
        if not training:
            predicted_sequence = torch_utils.get_seq_from_scores(
                all_scores, all_alignments)
            token_accuracy = torch_utils.per_token_accuracy(
                gold_sequence, predicted_sequence)
        fed_sequence = gold_sequence
    else:
        # Free decoding: feed back the model's own predictions.
        decoder_results = self.decoder(utterance_final_state, input_hidden_states, schema_states, max_generation_length, input_sequence=input_sequence, input_schema=input_schema, snippets=snippets, dropout_amount=self.dropout)
        predicted_sequence = decoder_results.sequence
        fed_sequence = predicted_sequence

    decoder_states = [
        pred.decoder_state for pred in decoder_results.predictions
    ]

    # fed_sequence contains EOS, which we don't need when encoding snippets.
    # also ignore the first state, as it contains the BEG encoding.
    # NOTE(review): the loop appends to decoder_states while zipping over a
    # snapshot slice of it (the slice is evaluated once, so iteration is
    # safe), leaving the original states followed by the expanded states --
    # confirm this doubling is what downstream consumers expect.
    for token, state in zip(fed_sequence[:-1], decoder_states[1:]):
        if snippet_handler.is_snippet(token):
            # A snippet token stands in for several output tokens; repeat
            # its decoder state once per underlying token.
            snippet_length = 0
            for snippet in snippets:
                if snippet.name == token:
                    snippet_length = len(snippet.sequence)
                    break
            assert snippet_length > 0
            decoder_states.extend([state for _ in range(snippet_length)])
        else:
            decoder_states.append(state)

    return (predicted_sequence, loss, token_accuracy, decoder_states, decoder_results)