Code example #1
def bow_snippets(token, snippets, output_embedder, input_schema):
    """ Bag of words embedding for snippets"""
    assert snippet_handler.is_snippet(token) and snippets

    snippet_sequence = []
    for snippet in snippets:
        if snippet.name == token:
            snippet_sequence = snippet.sequence
            break
    assert snippet_sequence

    if input_schema:
        snippet_embeddings = []
        for output_token in snippet_sequence:
            assert output_embedder.in_vocabulary(
                output_token) or input_schema.in_vocabulary(output_token,
                                                            surface_form=True)
            if output_embedder.in_vocabulary(output_token):
                snippet_embeddings.append(output_embedder(output_token))
            else:
                snippet_embeddings.append(
                    input_schema.column_name_embedder(output_token,
                                                      surface_form=True))
    else:
        snippet_embeddings = [
            output_embedder(subtoken) for subtoken in snippet_sequence
        ]

    snippet_embeddings = paddle.stack(
        snippet_embeddings, axis=0)  # len(snippet_sequence) x emb_size
    return paddle.mean(snippet_embeddings, axis=0)  # emb_size
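
For context, here is a self-contained sketch of the same bag-of-words idea using only real paddle calls; the vocabulary and token sequence are invented for illustration:

import paddle

# Embed each token of a (hypothetical) snippet expansion, stack, and average.
emb_size = 4
vocab = {"select": 0, "avg": 1, "(": 2, "age": 3, ")": 4}
embedding = paddle.nn.Embedding(len(vocab), emb_size)

sequence = ["select", "avg", "(", "age", ")"]  # a snippet's token expansion
token_embeddings = [
    embedding(paddle.to_tensor([vocab[t]], 'int64')).squeeze() for t in sequence
]
bow = paddle.mean(paddle.stack(token_embeddings, axis=0), axis=0)  # shape: [emb_size]
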
Code example #2
    def forward(self, token):
        assert isinstance(token, int) or not snippet_handler.is_snippet(
            token
        ), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets"

        if self.in_vocabulary(token):
            index_list = paddle.to_tensor(self.vocab_token_lookup(token),
                                          'int64')
            return self.token_embedding_matrix(index_list).squeeze()
        else:
            index_list = paddle.to_tensor(self.unknown_token_id, 'int64')
            return self.token_embedding_matrix(index_list).squeeze()
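
The in-vocabulary/unknown fallback above is the standard trick for handling open-vocabulary input; a minimal sketch, assuming an invented toy vocabulary in place of vocab_token_lookup and unknown_token_id:

import paddle

vocab = {"select": 0, "from": 1, "<unk>": 2}
embedding = paddle.nn.Embedding(len(vocab), 8)

def embed(token):
    # Out-of-vocabulary tokens fall back to the shared unknown-token row.
    index = vocab.get(token, vocab["<unk>"])
    return embedding(paddle.to_tensor([index], 'int64')).squeeze()  # shape: [8]

print(embed("select").shape, embed("group_by").shape)  # [8] [8]
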
Code example #3
File: decoder.py  Project: yuhaia/r2sql
    def get_output_token_embedding(self, output_token, input_schema, snippets):
        if self.params.use_snippets and snippet_handler.is_snippet(output_token):
            output_token_embedding = embedder.bow_snippets(output_token, snippets, self.output_embedder, input_schema)
        else:
            if input_schema:
                assert self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
                if self.output_embedder.in_vocabulary(output_token):
                    output_token_embedding = self.output_embedder(output_token)
                else:
                    output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True)
            else:
                output_token_embedding = self.output_embedder(output_token)
        return output_token_embedding
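
Read side by side with example #1, this dispatch has three cases. A condensed, dependency-free restatement (the callables here are stand-ins for the source's objects, not its API):

def embed_output_token(token, is_snippet, bow_snippets, output_embedder, input_schema):
    # 1) snippet tokens: bag-of-words over the snippet's expansion
    if is_snippet(token):
        return bow_snippets(token)
    # 2) tokens only the schema knows: surface-form column-name embedding
    if input_schema and not output_embedder.in_vocabulary(token):
        return input_schema.column_name_embedder(token, surface_form=True)
    # 3) ordinary output-vocabulary tokens
    return output_embedder(token)
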
Code example #4
    def forward(self, token):
        assert isinstance(token, int) or not snippet_handler.is_snippet(token), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets"

        if self.in_vocabulary(token):
            index_list = torch.LongTensor([self.vocab_token_lookup(token)])
            if self.token_embedding_matrix.weight.is_cuda:
                index_list = index_list.cuda()
            return self.token_embedding_matrix(index_list).squeeze()
        elif self.anonymizer and self.anonymizer.is_anon_tok(token):
            index_list = torch.LongTensor([self.anonymizer.get_anon_id(token)])
            if self.token_embedding_matrix.weight.is_cuda:
                index_list = index_list.cuda()
            return self.entity_embedding_matrix(index_list).squeeze()
        else:
            index_list = torch.LongTensor([self.unknown_token_id])
            if self.token_embedding_matrix.weight.is_cuda:
                index_list = index_list.cuda()
            return self.token_embedding_matrix(index_list).squeeze()
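
This PyTorch variant adds an anonymizer branch: anonymized entity tokens draw rows from a separate entity_embedding_matrix. A self-contained sketch of the three-way dispatch, with an invented vocabulary and anonymizer table:

import torch
import torch.nn as nn

vocab = {"select": 0, "from": 1, "<unk>": 2}
anon_ids = {"CITY_NAME#0": 0, "CITY_NAME#1": 1}  # hypothetical anonymized entities
token_embeddings = nn.Embedding(len(vocab), 8)
entity_embeddings = nn.Embedding(len(anon_ids), 8)

def embed(token):
    if token in vocab:                         # in-vocabulary token
        return token_embeddings(torch.LongTensor([vocab[token]])).squeeze()
    if token in anon_ids:                      # anonymizer branch
        return entity_embeddings(torch.LongTensor([anon_ids[token]])).squeeze()
    return token_embeddings(torch.LongTensor([vocab["<unk>"]])).squeeze()
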
Code example #5
    def predict_turn(self,
                     utterance_final_state,
                     input_hidden_states,
                     schema_states,
                     max_generation_length,
                     gold_query=None,
                     snippets=None,
                     input_sequence=None,
                     previous_queries=None,
                     previous_query_states=None,
                     input_schema=None,
                     feed_gold_tokens=False,
                     training=False):
        """ Gets a prediction for a single turn -- calls decoder and updates loss, etc.

        TODO:  this can probably be split into two methods, one that just predicts
            and another that computes the loss.
        """
        predicted_sequence = []
        fed_sequence = []
        loss = None
        token_accuracy = 0.

        if self.params.use_encoder_attention:
            schema_attention = self.utterance2schema_attention_module(
                torch.stack(schema_states, dim=0), input_hidden_states).vector  # input_value_size x len(schema)
            utterance_attention = self.schema2utterance_attention_module(
                torch.stack(input_hidden_states, dim=0), schema_states).vector  # schema_value_size x len(input)

            if schema_attention.dim() == 1:
                schema_attention = schema_attention.unsqueeze(1)
            if utterance_attention.dim() == 1:
                utterance_attention = utterance_attention.unsqueeze(1)

            new_schema_states = torch.cat([torch.stack(schema_states, dim=1), schema_attention], dim=0) # (input_value_size+schema_value_size) x len(schema)
            schema_states = list(torch.split(new_schema_states, split_size_or_sections=1, dim=1))
            schema_states = [schema_state.squeeze() for schema_state in schema_states]

            new_input_hidden_states = torch.cat([torch.stack(input_hidden_states, dim=1), utterance_attention], dim=0) # (input_value_size+schema_value_size) x len(input)
            input_hidden_states = list(torch.split(new_input_hidden_states, split_size_or_sections=1, dim=1))
            input_hidden_states = [input_hidden_state.squeeze() for input_hidden_state in input_hidden_states]

            # bi-lstm over schema_states and input_hidden_states (embedder is an identity function)
            if self.params.use_schema_encoder_2:
                final_schema_state, schema_states = self.schema_encoder_2(schema_states, lambda x: x, dropout_amount=self.dropout)
                final_utterance_state, input_hidden_states = self.utterance_encoder_2(input_hidden_states, lambda x: x, dropout_amount=self.dropout)

        if feed_gold_tokens:
            decoder_results = self.decoder(utterance_final_state,
                                           input_hidden_states,
                                           schema_states,
                                           max_generation_length,
                                           gold_sequence=gold_query,
                                           input_sequence=input_sequence,
                                           previous_queries=previous_queries,
                                           previous_query_states=previous_query_states,
                                           input_schema=input_schema,
                                           snippets=snippets,
                                           dropout_amount=self.dropout)

            all_scores = []
            all_alignments = []
            for prediction in decoder_results.predictions:
                scores = F.softmax(prediction.scores, dim=0)
                alignments = prediction.aligned_tokens
                if self.params.use_previous_query and self.params.use_copy_switch and len(previous_queries) > 0:
                    query_scores = F.softmax(prediction.query_scores, dim=0)
                    copy_switch = prediction.copy_switch
                    scores = torch.cat([scores * (1 - copy_switch), query_scores * copy_switch], dim=0)
                    alignments = alignments + prediction.query_tokens

                all_scores.append(scores)
                all_alignments.append(alignments)

            # Compute the loss
            gold_sequence = gold_query

            loss = torch_utils.compute_loss(gold_sequence, all_scores, all_alignments, get_token_indices)
            if not training:
                predicted_sequence = torch_utils.get_seq_from_scores(all_scores, all_alignments)
                token_accuracy = torch_utils.per_token_accuracy(gold_sequence, predicted_sequence)
            fed_sequence = gold_sequence
        else:
            decoder_results = self.decoder(utterance_final_state,
                                           input_hidden_states,
                                           schema_states,
                                           max_generation_length,
                                           input_sequence=input_sequence,
                                           previous_queries=previous_queries,
                                           previous_query_states=previous_query_states,
                                           input_schema=input_schema,
                                           snippets=snippets,
                                           dropout_amount=self.dropout)
            predicted_sequence = decoder_results.sequence
            fed_sequence = predicted_sequence

        decoder_states = [pred.decoder_state for pred in decoder_results.predictions]

        # fed_sequence contains EOS, which we don't need when encoding snippets.
        # also ignore the first state, as it contains the BEG encoding.

        for token, state in zip(fed_sequence[:-1], decoder_states[1:]):
            if snippet_handler.is_snippet(token):
                snippet_length = 0
                for snippet in snippets:
                    if snippet.name == token:
                        snippet_length = len(snippet.sequence)
                        break
                assert snippet_length > 0
                decoder_states.extend([state for _ in range(snippet_length)])
            else:
                decoder_states.append(state)

        return (predicted_sequence,
                loss,
                token_accuracy,
                decoder_states,
                decoder_results)
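
One detail worth isolating from the gold-token branch above is the copy-switch mixing: generation scores and previous-query copy scores are softmaxed separately, then scaled by (1 - copy_switch) and copy_switch so the concatenation is still a single probability distribution. A tiny numeric sketch (all values invented):

import torch
import torch.nn.functional as F

gen_logits = torch.tensor([2.0, 0.5, 0.1])  # over the output vocabulary
copy_logits = torch.tensor([1.0, 1.0])      # over previous-query tokens
copy_switch = torch.tensor(0.3)             # probability mass given to copying

scores = torch.cat([F.softmax(gen_logits, dim=0) * (1 - copy_switch),
                    F.softmax(copy_logits, dim=0) * copy_switch], dim=0)
print(scores.sum())  # tensor(1.) -- a valid distribution over both choices
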
Code example #6
    def predict_turn(self,
                     utterance_final_state,
                     input_hidden_states,
                     schema_states,
                     max_generation_length,
                     gold_query=None,
                     snippets=None,
                     input_sequence=None,
                     input_schema=None,
                     feed_gold_tokens=False,
                     training=False):
        """ Gets a prediction for a single turn -- calls decoder and updates loss, etc.

        TODO:  this can probably be split into two methods, one that just predicts
            and another that computes the loss.
        """
        predicted_sequence = []
        fed_sequence = []
        loss = None
        token_accuracy = 0.

        if feed_gold_tokens:
            decoder_results = self.decoder(utterance_final_state,
                                           input_hidden_states,
                                           schema_states,
                                           max_generation_length,
                                           gold_sequence=gold_query,
                                           input_sequence=input_sequence,
                                           input_schema=input_schema,
                                           snippets=snippets,
                                           dropout_amount=self.dropout)

            all_scores = [step.scores for step in decoder_results.predictions]
            all_alignments = [
                step.aligned_tokens for step in decoder_results.predictions
            ]

            # Compute the loss
            gold_sequence = gold_query

            loss = torch_utils.compute_loss(gold_sequence, all_scores,
                                            all_alignments, get_token_indices)
            if not training:
                predicted_sequence = torch_utils.get_seq_from_scores(
                    all_scores, all_alignments)
                token_accuracy = torch_utils.per_token_accuracy(
                    gold_sequence, predicted_sequence)
            fed_sequence = gold_sequence
        else:
            decoder_results = self.decoder(utterance_final_state,
                                           input_hidden_states,
                                           schema_states,
                                           max_generation_length,
                                           input_sequence=input_sequence,
                                           input_schema=input_schema,
                                           snippets=snippets,
                                           dropout_amount=self.dropout)
            predicted_sequence = decoder_results.sequence
            fed_sequence = predicted_sequence

        decoder_states = [
            pred.decoder_state for pred in decoder_results.predictions
        ]

        # fed_sequence contains EOS, which we don't need when encoding snippets.
        # also ignore the first state, as it contains the BEG encoding.

        for token, state in zip(fed_sequence[:-1], decoder_states[1:]):
            if snippet_handler.is_snippet(token):
                snippet_length = 0
                for snippet in snippets:
                    if snippet.name == token:
                        snippet_length = len(snippet.sequence)
                        break
                assert snippet_length > 0
                decoder_states.extend([state for _ in range(snippet_length)])
            else:
                decoder_states.append(state)

        return (predicted_sequence, loss, token_accuracy, decoder_states,
                decoder_results)
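
The closing loop in both predict_turn variants aligns decoder states with the flattened query: a snippet token contributes one copy of its state per token of its expansion. A stripped-down sketch that builds a fresh list instead of appending to the original (token names and lengths invented):

# states s0..s3 were produced while feeding ["select", "SNIPPET_0", "from"] + EOS;
# SNIPPET_0 expands to 3 tokens, so its single state is repeated 3 times.
fed_sequence = ["select", "SNIPPET_0", "from", "_EOS"]
snippet_lengths = {"SNIPPET_0": 3}
states = ["s0", "s1", "s2", "s3"]

expanded = []
for token, state in zip(fed_sequence[:-1], states[1:]):  # drop EOS and the BEG state
    expanded.extend([state] * snippet_lengths.get(token, 1))
print(expanded)  # ['s1', 's2', 's2', 's2', 's3']
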