Example #1
    def __init__(self,
                 vocab: Vocabulary,
                 mydatabase: str,
                 schema_path: str,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)

        self._exact_match = Average()
        self._action_similarity = Average()

        self._valid_sql_query = SqlValidity(mydatabase=mydatabase)
        self._token_match = TokenSequenceAccuracy()
        self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path)
        self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path)
        self._coverage_loss = CoverageAttentionLossMetric()

        # the padding value used by IndexField
        self._action_padding_index = -1
        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout)
        initializer(self)
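A note on the ``input_action_dim = action_embedding_dim + 1`` bookkeeping above: when ``add_action_bias`` is true, each action's input embedding carries one extra learned scalar that acts as a per-action bias. A minimal sketch of that shape logic (an illustration only, not AllenNLP's exact internals):

import torch

num_actions, action_embedding_dim = 5, 8
embedding = torch.nn.Embedding(num_actions, action_embedding_dim + 1)
embedded = embedding(torch.tensor([0, 3]))  # shape: (2, 9)
action_vectors = embedded[:, :-1]           # (2, 8): the embedding proper
action_bias = embedded[:, -1:]              # (2, 1): the learned per-action bias
print(action_vectors.shape, action_bias.shape)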
Example #2
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 dropout: float = 0.0) -> None:
        super(NlvrDirectSemanticParser,
              self).__init__(vocab=vocab,
                             sentence_embedder=sentence_embedder,
                             action_embedding_dim=action_embedding_dim,
                             encoder=encoder,
                             dropout=dropout)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            activation=Activation.by_name('tanh')(),
            add_action_bias=False,
            dropout=dropout)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        self._action_padding_index = -1
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels',
                 database_file: str = '/atis/atis.db') -> None:
        # Atis semantic parser init
        super().__init__(vocab)
        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        self._executor = SqlExecutor(database_file)
        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)


        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._num_entity_types = 2  # TODO(kevin): get this in a more principled way somehow?
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)
        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._transition_function = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                              action_embedding_dim=action_embedding_dim,
                                                              input_attention=input_attention,
                                                              add_action_bias=self._add_action_bias,
                                                              dropout=dropout,
                                                              num_layers=self._decoder_num_layers)
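The ``self._dropout = lambda x: x`` fallback above makes dropout a no-op when ``dropout == 0``. An equivalent that stays a proper ``torch.nn.Module`` (and so participates in ``state_dict`` and ``eval()`` traversal) is ``torch.nn.Identity``; a small sketch:

import torch

dropout_p = 0.0
# Same effect as the lambda trick above, but registered as a real submodule.
dropout = torch.nn.Dropout(p=dropout_p) if dropout_p > 0 else torch.nn.Identity()
x = torch.randn(2, 4)
assert torch.equal(dropout(x), x)  # identity when dropout_p == 0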
Example #4
    def __init__(self,
                 vocab: Vocabulary,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 database_file: str,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)

        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        self._executor = SqlExecutor(database_file)
        # the padding value used by IndexField
        self._action_padding_index = -1
        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            predict_start_type_separately=False,
            add_action_bias=self._add_action_bias,
            dropout=dropout)
        initializer(self)
Example #5
    def setUp(self):
        super().setUp()
        self.initial_state = SimpleState(
            [0, 1], [[], []],
            [torch.Tensor([0.0]), torch.Tensor([0.0])], [0, 1])
        self.decoder_step = SimpleTransitionFunction()
        self.targets = torch.Tensor([[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                                     [[3, 4, 0], [2, 3, 4], [0, 0, 0]]])
        self.target_mask = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                                         [[1, 1, 0], [1, 1, 1], [0, 0, 0]]])

        self.supervision = (self.targets, self.target_mask)
        # High beam size ensures exhaustive search.
        self.trainer = MaximumMarginalLikelihood()
Example #6
class TestMaximumMarginalLikelihood(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.initial_state = SimpleState([0, 1],
                                         [[], []],
                                         [torch.Tensor([0.0]), torch.Tensor([0.0])],
                                         [0, 1])
        self.decoder_step = SimpleTransitionFunction()
        self.targets = torch.Tensor([[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                                     [[3, 4, 0], [2, 3, 4], [0, 0, 0]]])
        self.target_mask = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                                         [[1, 1, 0], [1, 1, 1], [0, 0, 0]]])

        self.supervision = (self.targets, self.target_mask)
        # High beam size ensures exhaustive search.
        self.trainer = MaximumMarginalLikelihood()

    def test_decode(self):
        decoded_info = self.trainer.decode(self.initial_state, self.decoder_step, self.supervision)

        # Our loss is the negative log sum of the scores from each target sequence.  The score for
        # each sequence in our simple transition system is just `-sequence_length`.
        instance0_loss = math.log(math.exp(-3) * 3)  # all three sequences have length 3
        instance1_loss = math.log(math.exp(-2) + math.exp(-3))  # one has length 2, one has length 3
        expected_loss = -(instance0_loss + instance1_loss) / 2
        assert_almost_equal(decoded_info['loss'].data.numpy(), expected_loss)
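In equation form, the loss this test checks is the maximum marginal likelihood objective (my rendering of what the comments describe): for a batch of size $B$ with target sets $Y_i$ and per-sequence score $s(y)$,

\mathcal{L} = -\frac{1}{B} \sum_{i=1}^{B} \log \sum_{y \in Y_i} \exp\big(s(y)\big)

With $s(y) = -|y|$ in the simple transition system, instance 0 contributes $\log(3e^{-3})$ (three sequences of length 3) and instance 1 contributes $\log(e^{-2} + e^{-3})$, which matches ``expected_loss`` above.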
Example #7
class TestMaximumMarginalLikelihood(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.initial_state = SimpleState(
            [0, 1], [[], []],
            [torch.Tensor([0.0]), torch.Tensor([0.0])], [0, 1])
        self.decoder_step = SimpleTransitionFunction()
        self.targets = torch.Tensor([[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                                     [[3, 4, 0], [2, 3, 4], [0, 0, 0]]])
        self.target_mask = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                                         [[1, 1, 0], [1, 1, 1], [0, 0, 0]]])

        self.supervision = (self.targets, self.target_mask)
        # High beam size ensures exhaustive search.
        self.trainer = MaximumMarginalLikelihood()

    def test_decode(self):
        decoded_info = self.trainer.decode(self.initial_state,
                                           self.decoder_step, self.supervision)

        # Our loss is the negative log sum of the scores from each target sequence.  The score for
        # each sequence in our simple transition system is just `-sequence_length`.
        instance0_loss = math.log(math.exp(-3) * 3)  # all three sequences have length 3
        instance1_loss = math.log(math.exp(-2) + math.exp(-3))  # one has length 2, one has length 3
        expected_loss = -(instance0_loss + instance1_loss) / 2
        assert_almost_equal(decoded_info["loss"].data.numpy(), expected_loss)
Example #8
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 dropout: float = 0.0) -> None:
        super(NlvrDirectSemanticParser, self).__init__(vocab=vocab,
                                                       sentence_embedder=sentence_embedder,
                                                       action_embedding_dim=action_embedding_dim,
                                                       encoder=encoder,
                                                       dropout=dropout)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = BasicTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                     action_embedding_dim=action_embedding_dim,
                                                     input_attention=attention,
                                                     num_start_types=1,
                                                     activation=Activation.by_name('tanh')(),
                                                     predict_start_type_separately=False,
                                                     add_action_bias=False,
                                                     dropout=dropout)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        self._action_padding_index = -1
Example #9
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 training_beam_size: int = None,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(vocab=vocab,
                         question_embedder=question_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         entity_encoder=entity_encoder,
                         max_decoding_steps=max_decoding_steps,
                         add_action_bias=add_action_bias,
                         use_neighbor_similarity_for_linking=use_similarity,
                         dropout=dropout,
                         num_linking_features=num_linking_features,
                         rule_namespace=rule_namespace,
                         tables_directory=tables_directory)
        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._decoder_step = LinkingTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            num_start_types=self._num_start_types,
            predict_start_type_separately=True,
            add_action_bias=self._add_action_bias,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout)
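A usage note on the ``training_beam_size`` argument threaded into ``MaximumMarginalLikelihood`` above (the class docstrings later on this page explain it as a constrained beam search during training). A hedged sketch of the three configurations that appear in these examples; the import path assumes the AllenNLP 0.8.x layout these parsers come from:

from allennlp.state_machines.trainers import MaximumMarginalLikelihood

trainer_all = MaximumMarginalLikelihood()                 # beam_size=None: sum over all provided target sequences
trainer_topk = MaximumMarginalLikelihood(beam_size=5)     # constrained beam keeps only the top 5 sequences
trainer_greedy = MaximumMarginalLikelihood(beam_size=1)   # degenerates to following the single best prefix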
Example #10
    def __init__(self,
                 vocab: Vocabulary,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)

        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        # the padding value used by IndexField
        self._action_padding_index = -1
        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                            action_embedding_dim=action_embedding_dim,
                                                            input_attention=input_attention,
                                                            predict_start_type_separately=False,
                                                            add_action_bias=self._add_action_bias,
                                                            dropout=dropout)
        initializer(self)
Example #11
    def setUp(self):
        super().setUp()
        self.initial_state = SimpleState([0, 1],
                                         [[], []],
                                         [torch.Tensor([0.0]), torch.Tensor([0.0])],
                                         [0, 1])
        self.decoder_step = SimpleTransitionFunction()
        self.targets = torch.Tensor([[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                                     [[3, 4, 0], [2, 3, 4], [0, 0, 0]]])
        self.target_mask = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                                         [[1, 1, 0], [1, 1, 1], [0, 0, 0]]])

        self.supervision = (self.targets, self.target_mask)
        # High beam size ensures exhaustive search.
        self.trainer = MaximumMarginalLikelihood()
Example #12
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 training_beam_size: int = None,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(vocab=vocab,
                         question_embedder=question_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         entity_encoder=entity_encoder,
                         max_decoding_steps=max_decoding_steps,
                         add_action_bias=add_action_bias,
                         use_neighbor_similarity_for_linking=use_similarity,
                         dropout=dropout,
                         num_linking_features=num_linking_features,
                         rule_namespace=rule_namespace,
                         tables_directory=tables_directory)
        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                       action_embedding_dim=action_embedding_dim,
                                                       input_attention=attention,
                                                       num_start_types=self._num_start_types,
                                                       predict_start_type_separately=True,
                                                       add_action_bias=self._add_action_bias,
                                                       mixture_feedforward=mixture_feedforward,
                                                       dropout=dropout)
Example #13
class NlvrDirectSemanticParser(NlvrSemanticParser):
    """
    ``NlvrDirectSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of lack
    of logical form annotations by maximizing the marginal likelihood of an approximate set of target
    sequences that yield the correct denotation. The main difference between this parser and
    ``NlvrCoverageSemanticParser`` is that while this parser takes the output of an offline search
    process as the set of target sequences for training, the latter performs search during training.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the TransitionFunction.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        Maximum number of steps for beam search after training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 dropout: float = 0.0) -> None:
        super(NlvrDirectSemanticParser, self).__init__(vocab=vocab,
                                                       sentence_embedder=sentence_embedder,
                                                       action_embedding_dim=action_embedding_dim,
                                                       encoder=encoder,
                                                       dropout=dropout)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = BasicTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                     action_embedding_dim=action_embedding_dim,
                                                     input_attention=attention,
                                                     num_start_types=1,
                                                     activation=Activation.by_name('tanh')(),
                                                     predict_start_type_separately=False,
                                                     add_action_bias=False,
                                                     dropout=dropout)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        self._action_padding_index = -1

    @overrides
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRule]],
                identifier: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        label_strings = self._get_label_strings(labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
                                 range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if identifier is not None:
            outputs["identifier"] = identifier
        if target_action_sequences is not None:
            # Merge the trainer's loss into `outputs` instead of rebinding the dict,
            # so the "identifier" entry set above is not discarded.
            outputs.update(self._decoder_trainer.decode(initial_state,
                                                        self._decoder_step,
                                                        (target_action_sequences, target_mask)))
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                                 initial_state,
                                                                 self._decoder_step,
                                                                 keep_final_unfinished_states=False)
            best_action_sequences: Dict[int, List[List[int]]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = [best_final_states[i][0].action_history[0]]
                    best_action_sequences[i] = best_action_indices
            batch_action_strings = self._get_action_strings(actions, best_action_sequences)
            batch_denotations = self._get_denotations(batch_action_strings, worlds)
            if target_action_sequences is not None:
                self._update_metrics(action_strings=batch_action_strings,
                                     worlds=worlds,
                                     label_strings=label_strings)
            else:
                if metadata is not None:
                    outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
                outputs['debug_info'] = []
                for i in range(batch_size):
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                outputs["best_action_strings"] = batch_action_strings
                outputs["denotations"] = batch_denotations
                action_mapping = {}
                for batch_index, batch_actions in enumerate(actions):
                    for action_index, action in enumerate(batch_actions):
                        action_mapping[(batch_index, action_index)] = action[0]
                outputs['action_mapping'] = action_mapping
        return outputs

    def _update_metrics(self,
                        action_strings: List[List[List[str]]],
                        worlds: List[List[NlvrWorld]],
                        label_strings: List[List[str]]) -> None:
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): Using only the best decoded sequence. Define metrics for top-k sequences?
        batch_size = len(worlds)
        for i in range(batch_size):
            instance_action_strings = action_strings[i]
            sequence_is_correct = [False]
            if instance_action_strings:
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                # Taking only the best sequence.
                sequence_is_correct = self._check_denotation(instance_action_strings[0],
                                                             instance_label_strings,
                                                             instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
                'denotation_accuracy': self._denotation_accuracy.get_metric(reset),
                'consistency': self._consistency.get_metric(reset)
        }
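The target-mask construction used in ``forward`` above (squeeze away the trailing ``IndexField`` dimension, then compare against the padding index) is easy to check in isolation; a small self-contained sketch with made-up indices:

import torch

action_padding_index = -1
# Shape (batch=1, num_sequences=2, sequence_length=3, 1), as produced by
# ListField[ListField[IndexField]]; -1 marks padding.
target_action_sequences = torch.tensor([[[[2], [3], [-1]],
                                         [[1], [2], [4]]]])
target_action_sequences = target_action_sequences.squeeze(-1)   # -> (1, 2, 3)
target_mask = target_action_sequences != action_padding_index
print(target_mask)
# tensor([[[ True,  True, False],
#          [ True,  True,  True]]])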
Example #14
class WikiTablesVariableFreeMml(WikiTablesVariableFreeParser):
    """
    A ``WikiTablesVariableFreeMml`` is a :class:`WikiTablesVariableFreeSemanticParser` which is trained to
    maximize the marginal likelihood of an approximate set of logical forms which give the correct
    denotation. This is a re-implementation of the model used for the paper `Neural Semantic Parsing with Type
    Constraints for Semi-Structured Tables
    <https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a>`_,
    by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017). Note, however, that
    the language used by this model differs from the lambda-DCS used in that paper: this model
    uses the variable-free language from ``allennlp.semparse.type_declarations.wikitables_variable_free``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity. Passed to super class.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training. Passed to super class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    mixture_feedforward : ``FeedForward``, optional (default=None)
        If given, we'll use this to compute a mixture probability between global actions and linked
        actions given the hidden state at every timestep of decoding, instead of concatenating the
        logits for both (where the logits may not be compatible with each other).  Passed to
        the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.  Passed to super class.
    training_beam_size : ``int``, optional (default=None)
        If given, we will use a constrained beam search of this size during training, so that we
        use only the top ``training_beam_size`` action sequences according to the model in the MML
        computation.  If this is ``None``, we will use all of the provided action sequences in the
        MML computation.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to super
        class.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 training_beam_size: int = None,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels') -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(vocab=vocab,
                         question_embedder=question_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         entity_encoder=entity_encoder,
                         max_decoding_steps=max_decoding_steps,
                         add_action_bias=add_action_bias,
                         use_neighbor_similarity_for_linking=use_similarity,
                         dropout=dropout,
                         num_linking_features=num_linking_features,
                         rule_namespace=rule_namespace)
        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                       action_embedding_dim=action_embedding_dim,
                                                       input_attention=attention,
                                                       num_start_types=self._num_start_types,
                                                       predict_start_type_separately=True,
                                                       add_action_bias=self._add_action_bias,
                                                       mixture_feedforward=mixture_feedforward,
                                                       dropout=dropout)

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesVariableFreeWorld],
                actions: List[List[ProductionRuleArray]],
                target_values: List[List[str]] = None,
                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesVariableFreeWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesVariableFreeWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        target_values : ``List[List[str]]``, optional (default = None)
            For each instance, a list of target values taken from the example lisp string. We pass
            this list to the evaluator along with logical forms to compute denotation accuracy.
        target_action_sequences : torch.Tensor, optional (default = None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                           table,
                                                                           world,
                                                                           actions,
                                                                           outputs)
        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),  # type: ignore
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=rnn_state,
                                          grammar_state=grammar_state,
                                          possible_actions=actions,
                                          extras=target_values,
                                          debug_info=None)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)

            metadata = None
            self._compute_validation_outputs(actions,
                                             best_final_states,
                                             world,
                                             target_values,
                                             metadata,
                                             outputs)

            return outputs
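A side note on ``rnn_state[0].hidden_state.new_zeros(batch_size)`` above: ``new_zeros`` creates the initial scores with the same dtype and device as the encoder state, which is what keeps this code device-agnostic. A standalone illustration:

import torch

hidden_state = torch.randn(3, 10)          # stand-in for an encoder hidden state
initial_score = hidden_state.new_zeros(3)  # zeros sharing hidden_state's dtype and device
initial_score_list = [initial_score[i] for i in range(3)]
print(initial_score.dtype, len(initial_score_list))  # torch.float32 3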
Example #15
class WikiTablesMmlSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesMmlSemanticParser`` is a :class:`WikiTablesSemanticParser` which is trained to
    maximize the marginal likelihood of an approximate set of logical forms which give the correct
    denotation. This is a re-implementation of the model used for the paper `Neural Semantic Parsing with Type
    Constraints for Semi-Structured Tables
    <https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a>`_,
    by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017).

    WORK STILL IN PROGRESS.  We'll iteratively improve it until we've reproduced the performance of
    the original parser.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity. Passed to super class.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training. Passed to super class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    mixture_feedforward : ``FeedForward``, optional (default=None)
        If given, we'll use this to compute a mixture probability between global actions and linked
        actions given the hidden state at every timestep of decoding, instead of concatenating the
        logits for both (where the logits may not be compatible with each other).  Passed to
        the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.  Passed to super class.
    training_beam_size : ``int``, optional (default=None)
        If given, we will use a constrained beam search of this size during training, so that we
        use only the top ``training_beam_size`` action sequences according to the model in the MML
        computation.  If this is ``None``, we will use all of the provided action sequences in the
        MML computation.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to super
        class.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables. Passed to super class.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 training_beam_size: int = None,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(vocab=vocab,
                         question_embedder=question_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         entity_encoder=entity_encoder,
                         max_decoding_steps=max_decoding_steps,
                         add_action_bias=add_action_bias,
                         use_neighbor_similarity_for_linking=use_similarity,
                         dropout=dropout,
                         num_linking_features=num_linking_features,
                         rule_namespace=rule_namespace,
                         tables_directory=tables_directory)
        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                       action_embedding_dim=action_embedding_dim,
                                                       input_attention=attention,
                                                       num_start_types=self._num_start_types,
                                                       predict_start_type_separately=True,
                                                       add_action_bias=self._add_action_bias,
                                                       mixture_feedforward=mixture_feedforward,
                                                       dropout=dropout)

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRule]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default = None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default = None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' key.
        """
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                           table,
                                                                           world,
                                                                           actions,
                                                                           outputs)
        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),  # type: ignore
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=rnn_state,
                                          grammar_state=grammar_state,
                                          possible_actions=actions,
                                          extras=example_lisp_string,
                                          debug_info=None)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)

            self._compute_validation_outputs(actions,
                                             best_final_states,
                                             world,
                                             example_lisp_string,
                                             metadata,
                                             outputs)

            return outputs
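As a hedged, standalone illustration of the target-mask logic above (independent of the model; the tensor shape and the IndexField padding value of -1 are taken from the docstring), this minimal sketch shows what ``squeeze(-1)`` and the padding comparison produce:

import torch

action_padding_index = -1  # the padding value used by IndexField
# Shape (batch_size=2, num_action_sequences=1, sequence_length=4, 1),
# as produced by ListField[ListField[IndexField]].
target_action_sequences = torch.tensor(
    [[[[3], [7], [2], [-1]]],
     [[[5], [1], [-1], [-1]]]])

# Remove the trailing dimension added by IndexField ...
target_action_sequences = target_action_sequences.squeeze(-1)
# ... and mark everything that is not padding as a valid target position.
target_mask = target_action_sequences != action_padding_index
print(target_mask)
# tensor([[[ True,  True,  True, False]],
#         [[ True,  True, False, False]]])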
Ejemplo n.º 16
    def __init__(self,
                 vocab: Vocabulary,
                 decoder_beam_search: BeamSearch,
                 input_attention: Attention,
                 past_attention: Attention,
                 max_decoding_steps: int,
                 action_embedding_dim: int,
                 decoder_self_attend: bool = True,
                 parse_sql_on_decoding: bool = True,
                 add_action_bias: bool = True,
                 dataset_path: str = 'spider',
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels',
                 scoring_dev_params: dict = None,
                 debug_parsing: bool = False) -> None:

        super().__init__(vocab)
        self.vocab = vocab

        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        self._rule_namespace = rule_namespace

        self._add_action_bias = add_action_bias
        self._scoring_dev_params = scoring_dev_params or {}
        self.parse_sql_on_decoding = parse_sql_on_decoding
        self._self_attend = decoder_self_attend
        self._action_padding_index = -1

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._acc_single = Average()
        self._acc_multi = Average()
        self._beam_hit = Average()

        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        encoder_output_dim = 256
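        # This fixed 256-dimensional size must match the RAT encoder configured below
        # and the self._first_attended_utterance parameter initialized from it.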

        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder_output_dim))
        self._first_attended_output = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        torch.nn.init.normal_(self._first_attended_output)

        # This linear layer projects the initial 768-dimensional encoding down to the
        # 256-dimensional model size specified in the article and used by the RAT encoder below.
        self.input_linear = Linear(768, 256)

        # A single layer of the RAT encoder.
        self.rat_layer = RatEncoderLayer(d_model=256,
                                         nhead=8,
                                         dim_feedforward=1024)

        # The transformer encoder module consists of 8 RAT layers.
        self.encoder = RatEncoder(self.rat_layer, 8)

        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
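A hedged sketch of the encoder wiring in the example above. RatEncoderLayer and RatEncoder are project-specific (relation-aware transformer layers in the RAT-SQL style) and are not reproduced here; torch.nn.TransformerEncoder is used purely as a shape-compatible stand-in to show the 768 -> 256 projection feeding an 8-layer stack:

import torch
from torch import nn

# Assumed stand-in: nn.TransformerEncoderLayer replaces RatEncoderLayer but keeps
# the same d_model/nhead/dim_feedforward settings as the snippet above.
input_linear = nn.Linear(768, 256)   # e.g. a 768-d contextual embedding -> model dim
layer = nn.TransformerEncoderLayer(d_model=256, nhead=8,
                                   dim_feedforward=1024, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=8)

tokens = torch.randn(2, 15, 768)     # (batch, num_tokens, 768)
encoded = encoder(input_linear(tokens))
print(encoded.shape)                 # torch.Size([2, 15, 256])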
Ejemplo n.º 17
class AtisSemanticParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    input_attention: ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    database_file: ``str``, optional (default=/atis/atis.db)
        The path of the SQLite database when evaluating SQL queries. SQLite is disk based, so we need
        the file location to connect to it.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        utterance_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        decoder_beam_search: BeamSearch,
        max_decoding_steps: int,
        input_attention: Attention,
        add_action_bias: bool = True,
        training_beam_size: int = None,
        decoder_num_layers: int = 1,
        dropout: float = 0.0,
        rule_namespace: str = "rule_labels",
        database_file="/atis/atis.db",
    ) -> None:
        # Atis semantic parser init
        super().__init__(vocab)
        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        self._executor = SqlExecutor(database_file)
        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._num_entity_types = 2  # TODO(kevin): get this in a more principled way somehow?
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)
        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._transition_function = LinkingTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers,
        )

    @overrides
    def forward(
        self,  # type: ignore
        utterance: Dict[str, torch.LongTensor],
        world: List[AtisWorld],
        actions: List[List[ProductionRule]],
        linking_scores: torch.Tensor,
        target_action_sequence: torch.LongTensor = None,
        sql_queries: List[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        utterance : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        world : ``List[AtisWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        linking_scores: ``torch.Tensor``
            A matrix linking the utterance tokens and the entities. This binary matrix is
            deterministically generated, and each entry indicates whether a token is linked to an entity.
            This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
        target_action_sequence : ``torch.Tensor``, optional (default=None)
            The gold action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.
        sql_queries : List[List[str]], optional (default=None)
            A list of the SQL queries that are given during training or validation.
        """
        initial_state = self._get_initial_state(utterance, world, actions,
                                                linking_scores)
        batch_size = linking_scores.shape[0]
        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we unsqueeze it for
            # the MML trainer.
            return self._decoder_trainer.decode(
                initial_state,
                self._transition_function,
                (target_action_sequence.unsqueeze(1),
                 target_mask.unsqueeze(1)),
            )
        else:
            # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {"action_mapping": action_mapping}
            outputs["linking_scores"] = linking_scores
            if target_action_sequence is not None:
                outputs["loss"] = self._decoder_trainer.decode(
                    initial_state,
                    self._transition_function,
                    (target_action_sequence.unsqueeze(1),
                     target_mask.unsqueeze(1)),
                )["loss"]
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False,
            )
            outputs["best_action_sequence"] = []
            outputs["debug_info"] = []
            outputs["entities"] = []
            outputs["predicted_sql_query"] = []
            outputs["sql_queries"] = []
            outputs["utterance"] = []
            outputs["tokenized_utterance"] = []

            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs["predicted_sql_query"].append("")
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [
                    action_mapping[(i, action_index)]
                    for action_index in best_action_indices
                ]
                predicted_sql_query = action_sequence_to_sql(action_strings)

                if target_action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = target_action_sequence[i].data
                    sequence_in_targets = self._action_history_match(
                        best_action_indices, targets)
                        best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(
                        None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                if sql_queries and sql_queries[i]:
                    denotation_correct = self._executor.evaluate_sql_query(
                        predicted_sql_query, sql_queries[i])
                    self._denotation_accuracy(denotation_correct)
                    outputs["sql_queries"].append(sql_queries[i])

                outputs["utterance"].append(world[i].utterances[-1])
                outputs["tokenized_utterance"].append([
                    token.text for token in world[i].tokenized_utterances[-1]
                ])
                outputs["entities"].append(world[i].entities)
                outputs["best_action_sequence"].append(action_strings)
                outputs["predicted_sql_query"].append(
                    sqlparse.format(predicted_sql_query, reindent=True))
                outputs["debug_info"].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
            return outputs

    def _get_initial_state(
        self,
        utterance: Dict[str, torch.LongTensor],
        worlds: List[AtisWorld],
        actions: List[List[ProductionRule]],
        linking_scores: torch.Tensor,
    ) -> GrammarBasedState:
        embedded_utterance = self._utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size = embedded_utterance.size(0)
        num_entities = max([len(world.entities) for world in worlds])

        # entity_types: tensor with shape (batch_size, num_entities)
        entity_types, _ = self._get_type_vector(worlds, num_entities,
                                                embedded_utterance)

        # (batch_size, num_utterance_tokens, embedding_dim)
        encoder_input = embedded_utterance

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                initial_rnn_state.append(
                    RnnStatelet(
                        final_encoder_output[i].repeat(
                            self._decoder_num_layers, 1),
                        memory_cell[i].repeat(self._decoder_num_layers, 1),
                        self._first_action_embedding,
                        self._first_attended_utterance,
                        encoder_output_list,
                        utterance_mask_list,
                    ))
            else:
                initial_rnn_state.append(
                    RnnStatelet(
                        final_encoder_output[i],
                        memory_cell[i],
                        self._first_action_embedding,
                        self._first_attended_utterance,
                        encoder_output_list,
                        utterance_mask_list,
                    ))

        initial_grammar_state = [
            self._create_grammar_state(worlds[i], actions[i],
                                       linking_scores[i], entity_types[i])
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=None,
        )
        return initial_state

    @staticmethod
    def _get_type_vector(
        worlds: List[AtisWorld],
        num_entities: int,
        tensor: torch.Tensor = None
    ) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the encoding for each entity's type. In addition, a map from a flattened entity
        index to type is returned to combine entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[AtisWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []

        for batch_index, world in enumerate(worlds):
            types = []
            entities = [("number", entity) if any([
                entity.startswith(numeric_nonterminal)
                for numeric_nonterminal in NUMERIC_NONTERMINALS
            ]) else ("string", entity) for entity in world.entities]

            for entity_index, entity in enumerate(entities):
                # We need numbers to be first, then strings, since our entities are going to be
                # sorted. We do a split by type and then a merge later, and it relies on this sorting.
                if entity[0] == "number":
                    entity_type = 1
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)

        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
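        # Note: only the first len(predicted) target entries are compared, so trailing
        # padding (-1) in `targets` cannot block an exact match.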
        # Return 1 if the predicted sequence exactly matches the trimmed targets.
        return predicted_tensor.equal(targets_trimmed)

    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track four metrics here:

            1. exact_match, which is the percentage of the time that our best output action sequence
            matches the SQL query exactly.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that can be parsed. (make sure
            you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. valid_sql_query, which is the percentage of the time that decoding actually produces a
            valid SQL query.  We might not produce a valid SQL query if the decoder gets into a
            repetitive loop, or if we run out of time steps while producing a very long query.

            4. action_similarity, which is how similar the action sequence predicted is to the actual
            action sequence. This is basically a soft measure of exact_match.
        """
        return {
            "exact_match": self._exact_match.get_metric(reset),
            "denotation_acc": self._denotation_accuracy.get_metric(reset),
            "valid_sql_query": self._valid_sql_query.get_metric(reset),
            "action_similarity": self._action_similarity.get_metric(reset),
        }

    def _create_grammar_state(
        self,
        world: AtisWorld,
        possible_actions: List[ProductionRule],
        linking_scores: torch.Tensor,
        entity_types: torch.Tensor,
    ) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``AtisWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_utterance_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = world.valid_actions
        entity_map = {}
        entities = world.entities

        for entity_index, entity in enumerate(entities):
            entity_map[entity] = entity_index

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = (torch.cat(
                    global_action_tensors,
                    dim=0).to(entity_types.device).long())
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]["global"] = (
                    global_input_embeddings,
                    global_output_embeddings,
                    list(global_action_ids),
                )
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = list(linked_rules)
                entity_ids = [entity_map[entity] for entity in entities]
                entity_linking_scores = linking_scores[entity_ids]
                entity_type_tensor = entity_types[entity_ids]
                entity_type_embeddings = (
                    self._entity_type_decoder_embedding(entity_type_tensor).to(
                        entity_types.device).float())
                translated_valid_actions[key]["linked"] = (
                    entity_linking_scores,
                    entity_type_embeddings,
                    list(linked_action_ids),
                )

        return GrammarStatelet(["statement"], translated_valid_actions,
                               self.is_nonterminal)

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict["action_mapping"]
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict["debug_info"]
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(
                    predicted_actions, debug_info):
                action_info = {}
                action_info["predicted_action"] = predicted_action
                considered_actions = action_debug_info["considered_actions"]
                probabilities = action_debug_info["probabilities"]
                actions = []
                for action, probability in zip(considered_actions,
                                               probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)],
                                        probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info["considered_actions"] = considered_actions
                action_info["action_probabilities"] = probabilities
                action_info["utterance_attention"] = action_debug_info.get(
                    "question_attention", [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
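A brief, standalone illustration of the sqlparse.format call used in forward() above to pretty-print the predicted query (the query string here is a made-up ATIS-style example):

import sqlparse

query = "SELECT city.city_name FROM city WHERE city.state_code = 'TX'"
print(sqlparse.format(query, reindent=True))
# With reindent=True, each clause is placed on its own line, roughly:
#   SELECT city.city_name
#   FROM city
#   WHERE city.state_code = 'TX'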
Ejemplo n.º 18
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 question_embedder: TextFieldEmbedder,
                 input_attention: Attention,
                 past_attention: Attention,
                 max_decoding_steps: int,
                 action_embedding_dim: int,
                 gnn: bool = True,
                 decoder_use_graph_entities: bool = True,
                 decoder_self_attend: bool = True,
                 gnn_timesteps: int = 2,
                 parse_sql_on_decoding: bool = True,
                 add_action_bias: bool = True,
                 use_neighbor_similarity_for_linking: bool = True,
                 dataset_path: str = 'dataset',
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels',
                 scoring_dev_params: dict = None,
                 debug_parsing: bool = False) -> None:
        super().__init__(vocab)
        self.vocab = vocab
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._question_embedder = question_embedder
        self._add_action_bias = add_action_bias
        self._scoring_dev_params = scoring_dev_params or {}
        self.parse_sql_on_decoding = parse_sql_on_decoding
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        self._self_attend = decoder_self_attend
        self._decoder_use_graph_entities = decoder_use_graph_entities

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._acc_single = Average()
        self._acc_multi = Average()
        self._beam_hit = Average()

        self._action_embedding_dim = action_embedding_dim

        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        encoder_output_dim = encoder.get_output_dim()
        if gnn:
            encoder_output_dim += action_embedding_dim

        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder_output_dim))
        self._first_attended_output = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        torch.nn.init.normal_(self._first_attended_output)

        self._num_entity_types = 9
        self._embedding_dim = question_embedder.get_output_dim()

        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._linking_params = torch.nn.Linear(16, 1)
        torch.nn.init.uniform_(self._linking_params.weight, 0, 1)

        num_edge_types = 3
        self._gnn = GatedGraphConv(self._embedding_dim,
                                   gnn_timesteps,
                                   num_edge_types=num_edge_types,
                                   dropout=dropout)

        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        if decoder_self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                past_attention=past_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)

        self._ent2ent_ff = FeedForward(action_embedding_dim, 1,
                                       action_embedding_dim,
                                       Activation.by_name('relu')())

        self._neighbor_params = torch.nn.Linear(self._embedding_dim,
                                                self._embedding_dim)

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(
            evaluate,
            db_dir=os.path.join(dataset_path, 'database'),
            table=os.path.join(dataset_path, 'tables.json'),
            check_valid=False)

        self.debug_parsing = debug_parsing
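A minimal sketch of the functools.partial pattern used above to pre-bind the dataset paths; `evaluate` here is a stand-in with an assumed signature, not the project's actual Spider evaluator:

import os
from functools import partial

def evaluate(predicted, gold, db_dir, table, check_valid):
    """Stand-in: the real function executes/compares SQL against the databases."""
    ...

dataset_path = 'dataset'
evaluate_func = partial(
    evaluate,
    db_dir=os.path.join(dataset_path, 'database'),
    table=os.path.join(dataset_path, 'tables.json'),
    check_valid=False)

# Later calls only need to supply the per-example arguments:
# evaluate_func(predicted_sql, gold_sql)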
Ejemplo n.º 19
    def __init__(
        self,
        vocab: Vocabulary,
        encoder: Seq2SeqEncoder,
        dropout: float = 0.0,
        object_loss_multiplier: float = 0.0,
        denotation_loss_multiplier: float = 1.0,
        tokens_namespace: str = "tokens",
        rule_namespace: str = "rule_labels",
        denotation_namespace: str = "labels",
        num_parse_only_batches: int = 0,
        use_gold_program_for_eval: bool = False,
        nmn_settings: Dict = None,
    ) -> None:
        super().__init__(vocab)
        self._encoder = encoder
        self._dropout = torch.nn.Dropout(p=dropout)
        self._obj_loss_multiplier = object_loss_multiplier
        self._denotation_loss_multiplier = denotation_loss_multiplier
        self._tokens_namespace = tokens_namespace
        self._rule_namespace = rule_namespace
        self._denotation_namespace = denotation_namespace
        self._num_parse_only_batches = num_parse_only_batches
        self._use_gold_program_for_eval = use_gold_program_for_eval
        self._nmn_settings = nmn_settings
        self._training_batches_so_far = 0

        self._denotation_accuracy = CategoricalAccuracy()
        self._proposal_accuracy = CategoricalAccuracy()
        # TODO(mattg): use FullSequenceMatch instead of this.
        self._program_accuracy = Average()
        self.loss = torch.nn.BCELoss()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        action_embedding_dim = 100
        self._add_action_bias = True
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        self._language_parameters = VisualReasoningGqaParameters(
            hidden_dim=self._encoder.get_output_dim(),
            initializer=self._encoder.encoder.model.init_bert_weights,
        )

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        # encoder_output_dim = self._lxrt_encoder.get_output_dim()
        encoder_output_dim = self._encoder.get_output_dim()
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder_output_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._decoder_num_layers = 1

        self._beam_search = BeamSearch(beam_size=10)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=encoder_output_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=AdditiveAttention(vector_dim=encoder_output_dim,
                                              matrix_dim=encoder_output_dim),
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers,
        )
        self._language_parameters.apply(
            self._encoder.encoder.model.init_bert_weights)
        # attention.apply(self._lxrt_encoder.encoder.model.init_bert_weights)
        # self._transition_function.apply(self._lxrt_encoder.encoder.model.init_bert_weights)

        # Our language is constant across instances, so we just create one up front that we can
        # re-use to construct the `GrammarStatelet`.
        self._world = VisualReasoningGqaLanguage(None, None, None, None, None)
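A hedged sketch of additive (Bahdanau-style) attention, the role played by AdditiveAttention(vector_dim, matrix_dim) above; this is an illustrative re-implementation, not AllenNLP's exact module:

import torch
from torch import nn

class AdditiveAttentionSketch(nn.Module):
    # Scores each key as v . tanh(W_q q + W_k k), then softmaxes over keys.
    def __init__(self, vector_dim: int, matrix_dim: int) -> None:
        super().__init__()
        self.w_q = nn.Linear(vector_dim, vector_dim, bias=False)
        self.w_k = nn.Linear(matrix_dim, vector_dim, bias=False)
        self.v = nn.Linear(vector_dim, 1, bias=False)

    def forward(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
        # vector: (batch, vector_dim); matrix: (batch, num_keys, matrix_dim)
        scores = self.v(torch.tanh(self.w_q(vector).unsqueeze(1) + self.w_k(matrix)))
        return scores.squeeze(-1).softmax(dim=-1)   # (batch, num_keys)

attn = AdditiveAttentionSketch(vector_dim=512, matrix_dim=512)
weights = attn(torch.randn(2, 512), torch.randn(2, 7, 512))
print(weights.shape)  # torch.Size([2, 7])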
Ejemplo n.º 20
class MMParser(Model):
    # Seq2Seq encoder with Memory Network
    #
    #  Architecture
    #   1) RNN encoder for input symbols in query and input memory items (either shared or separate)
    #   2) RNN encoder for output symbols in the output memory items
    #   3) Key-value memory for embedding query items with support context
    #   4) State machine decoder for output symbols based on SQL grammar
    def __init__(self,
                 question_embedder: TextFieldEmbedder,
                 input_memory_embedder: TextFieldEmbedder,
                 output_memory_embedder: TextFieldEmbedder,
                 question_encoder: Seq2SeqEncoder,
                 input_memory_encoder: Seq2VecEncoder,
                 output_memory_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 input_attention: Attention,
                 past_attention: Attention,
                 action_embedding_dim: int,
                 max_decoding_steps: int,
                 nhop: int,
                 decoding_nhop: int,
                 vocab: Vocabulary,
                 dataset_path: str = 'dataset',
                 parse_sql_on_decoding: bool = True,
                 training_beam_size: int = None,
                 add_action_bias: bool = True,
                 decoder_self_attend: bool = True,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels') -> None:
        super().__init__(vocab)

        self.question_embedder = question_embedder
        self._input_mm_embedder = input_memory_embedder
        self._output_mm_embedder = output_memory_embedder
        self._question_encoder = question_encoder
        self._input_mm_encoder = TimeDistributed(input_memory_encoder)
        self._output_mm_encoder = TimeDistributed(output_memory_encoder)

        self.parse_sql_on_decoding = parse_sql_on_decoding
        self._self_attend = decoder_self_attend
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._rule_namespace = rule_namespace
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._input_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        self._num_entity_types = 9
        self._entity_type_decoder_input_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)
        self._entity_type_decoder_output_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types,
            question_encoder.get_output_dim() // 2)

        self._decoder_num_layers = decoder_num_layers
        self._action_embedding_dim = action_embedding_dim

        self._ent2ent_ff = FeedForward(action_embedding_dim, 1,
                                       action_embedding_dim,
                                       Activation.by_name('relu')())

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(question_encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        if self._self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(
                encoder_output_dim=question_encoder.get_output_dim(),
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                past_attention=past_attention,
                decoding_nhop=decoding_nhop,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(
                encoder_output_dim=question_encoder.get_output_dim(),
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)

        self._mm_attn = MemAttn(question_encoder.get_output_dim(), nhop)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._acc_single = Average()
        self._acc_multi = Average()
        self._beam_hit = Average()

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(
            evaluate,
            db_dir=os.path.join(dataset_path, 'database'),
            table=os.path.join(dataset_path, 'tables.json'),
            check_valid=False)

    @overrides
    def forward(
            self,
            utterance: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            world: List[SpiderWorld],
            schema: Dict[str, torch.LongTensor],
            action_sequence: torch.LongTensor = None
    ) -> Dict[str, torch.Tensor]:

        device = utterance['tokens'].device
        initial_state = self._get_initial_state(utterance, world, schema,
                                                valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            action_mask = action_sequence != self._action_padding_index
        else:
            action_mask = None

        if self.training:
            decode_output = self._decoder_trainer.decode(
                initial_state, self._transition_function,
                (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))

            return {'loss': decode_output['loss']}
        else:
            loss = torch.tensor([0]).float().to(device)
            if action_sequence is not None and action_sequence.size(1) > 1:
                try:
                    loss = self._decoder_trainer.decode(
                        initial_state, self._transition_function,
                        (action_sequence.unsqueeze(1),
                         action_mask.unsqueeze(1)))['loss']
                except ZeroDivisionError:
                    # reached a dead-end during beam search
                    pass

            outputs: Dict[str, Any] = {'loss': loss}

            num_steps = self._max_decoding_steps
            # Debug-info tracking is disabled here; enabling the line below would make the
            # state record per-step debug info to pass along in the output dictionary.
            # initial_state.debug_info = [[] for _ in range(batch_size)]

            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False)

            self._compute_validation_outputs(valid_actions, best_final_states,
                                             world, action_sequence, outputs)
        return outputs

    def _get_initial_state(
            self, utterance: Dict[str, torch.LongTensor],
            worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]]) -> GrammarBasedState:

        utterance_mask = util.get_text_field_mask(utterance).float()
        embedded_utterance = self.question_embedder(utterance)
        batch_size, _, _ = embedded_utterance.size()
        encoder_outputs = self._dropout(
            self._question_encoder(embedded_utterance, utterance_mask))

        schema_text = schema['text']
        input_mm_schema = self._input_mm_embedder(schema_text,
                                                  num_wrapping_dims=1)
        output_mm_schema = self._output_mm_embedder(schema_text,
                                                    num_wrapping_dims=1)
        batch_size, num_entities, num_entity_tokens, _ = input_mm_schema.size()
        schema_mask = util.get_text_field_mask(schema_text,
                                               num_wrapping_dims=1).float()

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index.
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            worlds, num_entities, input_mm_schema.device)
        # (batch_size, num_entities, embedding_dim)
        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # (batch_size, num_entities, embedding_dim)
        # Each entity's memory representation is the concatenation of two parts:
        # 1. the entity-token embedding, and
        # 2. the entity-type embedding.
        K = torch.cat(
            [self._input_mm_encoder(input_mm_schema, schema_mask),
             entity_type_embeddings],
            dim=2)
        V = torch.cat(
            [self._output_mm_encoder(output_mm_schema, schema_mask),
             entity_type_embeddings],
            dim=2)
        encoder_output_dim = self._question_encoder.get_output_dim()

        # Encodes utterance in the context of the schema, which is stored in external memory
        encoder_outputs_with_context, attn_weights = self._mm_attn(
            encoder_outputs, K, V)
        attn_weights = attn_weights.transpose(1, 2)
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs_with_context, utterance_mask,
            self._question_encoder.is_bidirectional())

        max_entities_relevance = attn_weights.max(dim=2)[0]
        entities_relevance = max_entities_relevance.unsqueeze(-1).detach()
        if self._self_attend:
            entities_ff = self._ent2ent_ff(entity_type_embeddings *
                                           entities_relevance)
            linked_actions_linking_scores = torch.bmm(
                entities_ff, entities_ff.transpose(1, 2))
        else:
            linked_actions_linking_scores = [None] * batch_size

        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        # RnnStatelet keeps track of the internal state of the decoder RNN:
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(worlds[i], valid_actions[i],
                                       attn_weights[i],
                                       linked_actions_linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]

        initial_sql_state = [
            SqlState(valid_actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=valid_actions,
            action_entity_mapping=[
                w.get_action_entity_mapping() for w in worlds
            ])

        return initial_state

    @staticmethod
    def _get_type_vector(worlds: List[SpiderWorld], num_entities: int,
                         device) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the encoding for each entity's type. In addition, a map from a flattened entity
        index to type is returned to combine entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        num_entities : ``int``
        device : ``torch.device``
            Used for placing the constructed tensor on the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []

        column_type_ids = [
            'boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time'
        ]

        for batch_index, world in enumerate(worlds):
            types = []

            for entity_index, entity in enumerate(
                    world.db_context.knowledge_graph.entities):
                parts = entity.split(':')
                entity_main_type = parts[0]
                if entity_main_type == 'column':
                    column_type = parts[1]
                    entity_type = column_type_ids.index(column_type)
                elif entity_main_type == 'string':
                    # cell value
                    entity_type = len(column_type_ids)
                elif entity_main_type == 'table':
                    entity_type = len(column_type_ids) + 1
                else:
                    raise ValueError("Unknown entity: " + entity)
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)

        return torch.tensor(batch_types, dtype=torch.long,
                            device=device), entity_types

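The type-id scheme implemented above can be exercised on its own. The entity strings below are hypothetical, but they follow the 'column:<type>:...', 'string:...' and 'table:...' format the method expects; the two extra ids after the seven column types account for the nine entity types this model uses:

column_type_ids = [
    'boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time'
]

def entity_type_id(entity: str) -> int:
    parts = entity.split(':')
    if parts[0] == 'column':
        return column_type_ids.index(parts[1])
    if parts[0] == 'string':             # cell value
        return len(column_type_ids)      # 7
    if parts[0] == 'table':
        return len(column_type_ids) + 1  # 8
    raise ValueError("Unknown entity: " + entity)

assert entity_type_id('column:text:singer:name') == 5
assert entity_type_id('table:singer') == 8
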
    def _create_grammar_state(self, world: SpiderWorld,
                              possible_actions: List[ProductionRule],
                              attn_weights: torch.Tensor,
                              linked_actions_linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = world.valid_actions
        entity_map = {}
        entities = world.entities_names
        for entity_index, entity in enumerate(entities):
            entity_map[entity] = entity_index

        # The value tuples differ in arity between 'global' and 'linked'
        # entries, so the annotation is kept loose here.
        translated_valid_actions: Dict[str, Dict[str, Tuple]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(
                    global_action_tensors,
                    dim=0).to(global_action_tensors[0].device).long()
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_input_action_embeddings = self._input_action_embedder(
                    global_action_tensor)
                global_output_action_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_input_action_embeddings,
                    list(global_action_ids), global_output_action_embeddings)

            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [
                    rule.split(' -> ')[1].strip('[]\"')
                    for rule in linked_rules
                ]
                entity_ids = [entity_map[entity] for entity in entities]

                entity_linking_scores = attn_weights[entity_ids]
                if linked_actions_linking_scores is not None:
                    entity_action_linking_scores = linked_actions_linking_scores[
                        entity_ids]
                entity_type_tensor = entity_types[entity_ids]
                entity_type_input_embeddings = (
                    self._entity_type_decoder_input_embedding(
                        entity_type_tensor).to(entity_types.device).float())
                entity_type_output_embeddings = (
                    self._entity_type_decoder_output_embedding(
                        entity_type_tensor).to(entity_types.device).float())

                if self._self_attend:
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_input_embeddings,
                        list(linked_action_ids), entity_action_linking_scores,
                        entity_type_output_embeddings)
                else:
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_input_embeddings,
                        list(linked_action_ids), entity_type_output_embeddings)

        return GrammarStatelet(['statement'], translated_valid_actions,
                               self.is_nonterminal)

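For linked actions, the entity name is recovered from the production-rule string itself. A tiny check with a hypothetical rule shows the split/strip step used above:

rule = 'column_name -> ["singer.name"]'
entity = rule.split(' -> ')[1].strip('[]"')
assert entity == 'singer.name'
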
    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence is a prefix of the (padded) target sequence.
        return torch.max(
            torch.min(targets_trimmed.eq(predicted_tensor), dim=0)[0]).item()

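A worked example of the prefix-match trick above, with toy action indices (the .long() cast is only for portability across torch versions):

import torch

targets = torch.tensor([3, 7, 2, -1, -1])  # -1 is the IndexField padding value
predicted = [3, 7, 2]

predicted_tensor = targets.new_tensor(predicted)
targets_trimmed = targets[:len(predicted)]
eq = targets_trimmed.eq(predicted_tensor).long()
# min over positions == "all positions equal"; max reduces to a scalar 0/1
assert torch.max(torch.min(eq, dim=0)[0]).item() == 1
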
    @staticmethod
    def _query_difficulty(targets: torch.LongTensor, action_mapping,
                          batch_index):
        number_tables = len([
            action_mapping[(batch_index, int(a))] for a in targets
            if a >= 0 and action_mapping[(batch_index,
                                          int(a))].startswith('table_name')
        ])
        return number_tables > 1

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            '_match/exact_match': self._exact_match.get_metric(reset),
            'sql_match': self._sql_evaluator_match.get_metric(reset),
            '_others/action_similarity':
            self._action_similarity.get_metric(reset),
            '_match/match_single': self._acc_single.get_metric(reset),
            '_match/match_hard': self._acc_multi.get_metric(reset),
            'beam_hit': self._beam_hit.get_metric(reset)
        }

    def find_shortest_path(self, start, end, graph):
        # Breadth-first search over the foreign-key graph, so the shortest join
        # chain is found first.  Returns the path from ``start`` to ``end`` as a
        # list of (table, (from_column, to_column)) steps, or ``None`` if the
        # two tables are not connected.
        queue = [(start, [])]
        visited = {start}
        while queue:
            ele, history = queue.pop(0)
            if ele == end:
                return history
            for node, columns in graph[ele]:
                if node not in visited:
                    queue.append((node, history + [(node, columns)]))
                    visited.add(node)
        return None

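A usage sketch with a hypothetical two-hop foreign-key graph (table ids 0-2; column ids are made up). Since the method does not use ``self``, a standalone copy suffices for illustration:

from collections import defaultdict

def find_shortest_path(start, end, graph):
    queue = [(start, [])]
    visited = {start}
    while queue:
        ele, history = queue.pop(0)
        if ele == end:
            return history
        for node, columns in graph[ele]:
            if node not in visited:
                queue.append((node, history + [(node, columns)]))
                visited.add(node)
    return None

graph = defaultdict(list)
graph[0].append((1, (3, 10)))   # singer.id <-> concert.singer_id
graph[1].append((0, (10, 3)))
graph[1].append((2, (11, 20)))  # concert.stadium_id <-> stadium.id
graph[2].append((1, (20, 11)))

# Join chain singer -> concert -> stadium:
assert find_shortest_path(0, 2, graph) == [(1, (3, 10)), (2, (11, 20))]
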
    def _add_from_clause(self, origin_query, world: SpiderWorld):
        predicted_sql_query_tokens = origin_query.split(" ")
        # print("predicted_sql_query_tokens:{}".format(predicted_sql_query_tokens))

        select_indices = [
            i for i, x in enumerate(predicted_sql_query_tokens)
            if x == "select"
        ]
        select_indices.append(len(predicted_sql_query_tokens))
        # From bottom to top
        select_indices.reverse()
        with open(world.db_context.tables_file, "r") as tables_file:
            dbs_json_blob = json.load(tables_file)
        graph = defaultdict(list)
        table_list = []
        dbtable = {}
        for table in dbs_json_blob:
            if world.db_id == table['db_id']:
                dbtable = table
                for acol, bcol in table["foreign_keys"]:
                    t1 = table["column_names"][acol][0]
                    t2 = table["column_names"][bcol][0]
                    graph[t1].append((t2, (acol, bcol)))
                    graph[t2].append((t1, (bcol, acol)))
                table_list = list(table["table_names_original"])
        # print("table_list:{}".format(table_list))

        end_idx = select_indices[0]
        for index in select_indices[1:]:
            table_alias_dict = {}
            idx = 1

            start_idx = index
            tables = {
                token.split(".")[0]
                for token in predicted_sql_query_tokens[start_idx:end_idx]
                if '.' in token
            }
            candidate_tables: List[int] = []
            for table in tables:
                for i, table1 in enumerate(table_list):
                    if table1.lower() == table:
                        candidate_tables.append(i)
                        break
            # print("\ncandidate_tables:{}".format(candidate_tables))
            ret = ""
            flag_only_sel_count = False
            if len(candidate_tables) > 1:
                start = candidate_tables[0]
                table_alias_dict[start] = idx
                idx += 1
                ret = "from {}".format(dbtable["table_names_original"][start])
                try:
                    for end in candidate_tables[1:]:
                        if end in table_alias_dict:
                            continue
                        path = self.find_shortest_path(start, end, graph)
                        # print("got path = {}".format(path))
                        prev_table = start
                        if not path:
                            table_alias_dict[end] = idx
                            idx += 1
                            ret = "{} join {}".format(
                                ret, dbtable["table_names_original"][end])
                            continue
                        for node, (acol, bcol) in path:
                            if node in table_alias_dict:
                                prev_table = node
                                continue
                            table_alias_dict[node] = idx
                            idx += 1
                            # print("test every slot:")
                            # print("table:{}, dbtable:{}".format(table, dbtable))
                            # print(dbtable["table_names_original"][node])
                            # print(dbtable["table_names_original"][prev_table])
                            # print(dbtable["column_names_original"][acol][1])
                            # print(dbtable["table_names_original"][node])
                            # print(dbtable["column_names_original"][bcol][1])
                            ret = "{} join {} on {}.{} = {}.{}".format(
                                ret, dbtable["table_names_original"][node],
                                dbtable["table_names_original"][prev_table],
                                dbtable["column_names_original"][acol][1],
                                dbtable["table_names_original"][node],
                                dbtable["column_names_original"][bcol][1])
                            prev_table = node

                except Exception:
                    print("Exception while building the FROM clause; "
                          "predicted_sql_query_tokens: {}".format(
                              predicted_sql_query_tokens))
            # If all the columns come from one table, generate the FROM clause directly
            elif len(candidate_tables) == 1:
                ret = "from {}".format(tables.pop())
            else:
                ret = 'from'
                flag_only_sel_count = True

            if not flag_only_sel_count:
                flag = False
                index = start_idx + len(
                    predicted_sql_query_tokens[start_idx:end_idx])
                brace_count = 0
                for ii, token in enumerate(
                        predicted_sql_query_tokens[start_idx:end_idx]):
                    if token == "(":
                        brace_count += 1
                    if token == ")":
                        if brace_count == 0:
                            index = ii + start_idx
                            predicted_sql_query_tokens = predicted_sql_query_tokens[:index] + [ret] + \
                                                         predicted_sql_query_tokens[index:]
                            flag = True
                            break
                        else:
                            brace_count -= 1
                    if token == "where" or token == "group" or token == "order":
                        index = ii + start_idx
                        predicted_sql_query_tokens = predicted_sql_query_tokens[:index] + [ret] + \
                                                     predicted_sql_query_tokens[index:]
                        flag = True
                        break
                if not flag:
                    predicted_sql_query_tokens = predicted_sql_query_tokens[:index] + [
                        ret
                    ]
                    # print("\npredicted_sql_query_tokens:{}".format(' '.join([token for token in predicted_sql_query_tokens])))
            else:
                for ii, token in enumerate(
                        predicted_sql_query_tokens[start_idx:end_idx]):
                    if token == "from_count":
                        predicted_sql_query_tokens = predicted_sql_query_tokens[:ii] + [ret] + \
                                                     predicted_sql_query_tokens[ii + 1:]
                        # print("\npredicted_sql_query:{}".format(' '.join([token for token in predicted_sql_query_tokens])))
                        break
            end_idx = start_idx
            # print("predicted_sql_query_tokens:{}".format(predicted_sql_query_tokens))
        return ' '.join([
            '*' if '.*' in token else token
            for token in predicted_sql_query_tokens
        ])

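The single-table case above reduces to splicing a FROM clause in front of the first top-level where/group/order token. A much-simplified sketch of just that case (one SELECT, no subqueries; names are illustrative):

def add_simple_from_clause(sql: str) -> str:
    tokens = sql.split(' ')
    # Tables are referenced as table.column in the decoded query.
    tables = {token.split('.')[0] for token in tokens if '.' in token}
    if len(tables) != 1:
        raise ValueError('this sketch only handles exactly one table')
    from_clause = 'from {}'.format(tables.pop())
    for i, token in enumerate(tokens):
        if token in ('where', 'group', 'order'):
            return ' '.join(tokens[:i] + [from_clause] + tokens[i:])
    return ' '.join(tokens + [from_clause])

assert (add_simple_from_clause('select singer.name where singer.age > 30')
        == 'select singer.name from singer where singer.age > 30')
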
    def _compute_validation_outputs(self,
                                    actions: List[List[ProductionRuleArray]],
                                    best_final_states: Mapping[
                                        int, Sequence[GrammarBasedState]],
                                    world: List[SpiderWorld],
                                    target_list: List[List[str]],
                                    outputs: Dict[str, Any]) -> None:
        batch_size = len(actions)

        outputs['predicted_sql_query'] = []

        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]

        for i in range(batch_size):
            # gold sql exactly as given
            original_gold_sql_query = ' '.join(
                world[i].get_query_without_table_hints())

            if i not in best_final_states:
                self._exact_match(0)
                self._action_similarity(0)
                self._sql_evaluator_match(0)
                self._acc_multi(0)
                self._acc_single(0)
                outputs['predicted_sql_query'].append('')
                continue

            best_action_indices = best_final_states[i][0].action_history[0]

            action_strings = [
                action_mapping[(i, action_index)]
                for action_index in best_action_indices
            ]
            predicted_sql_query = action_sequence_to_sql(action_strings,
                                                         add_table_names=True)
            # print ("predicted_sql_query:{}".format(predicted_sql_query))

            predicted_sql_query = self._add_from_clause(
                predicted_sql_query, world[i])
            outputs['predicted_sql_query'].append(
                sqlparse.format(predicted_sql_query, reindent=False))

            targets = target_list[i].data if target_list is not None else None
            target_available = targets is not None and targets[0] > -1

            if target_available:
                sequence_in_targets = self._action_history_match(
                    best_action_indices, targets)
                self._exact_match(sequence_in_targets)

                sql_evaluator_match = self._evaluate_func(
                    original_gold_sql_query, predicted_sql_query,
                    world[i].db_id)
                self._sql_evaluator_match(sql_evaluator_match)

                similarity = difflib.SequenceMatcher(None, best_action_indices,
                                                     targets)
                self._action_similarity(similarity.ratio())

                difficulty = self._query_difficulty(targets, action_mapping, i)
                if difficulty:
                    self._acc_multi(sql_evaluator_match)
                else:
                    self._acc_single(sql_evaluator_match)

            # beam_hit: does any candidate on the beam evaluate to the gold query?
            beam_hit = False
            for final_state in best_final_states[i]:
                action_indices = final_state.action_history[0]
                action_strings = [
                    action_mapping[(i, action_index)]
                    for action_index in action_indices
                ]
                candidate_sql_query = action_sequence_to_sql(
                    action_strings, add_table_names=True)

                if target_available:
                    correct = self._evaluate_func(original_gold_sql_query,
                                                  candidate_sql_query,
                                                  world[i].db_id)
                    if correct:
                        beam_hit = True
                        break
            # Record the metric once per example, not once per beam candidate.
            if target_available:
                self._beam_hit(beam_hit)
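
The action_similarity metric recorded above is just difflib's similarity ratio over the two action-index sequences. A toy example, with plain Python lists standing in for the tensors:

import difflib

best_action_indices = [4, 12, 7, 7, 31]
targets = [4, 12, 9, 7, 31]
similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
# ratio() = 2 * matches / total length = 2 * 4 / 10
assert round(similarity.ratio(), 2) == 0.8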
Ejemplo n.º 21
0
    def __init__(
            self,
            vocab: Vocabulary,
            question_embedder: TextFieldEmbedder,
            action_embedding_dim: int,
            encoder: Seq2SeqEncoder,
            decoder_beam_search: BeamSearch,
            max_decoding_steps: int,
            attention: Attention,
            mixture_feedforward: FeedForward = None,
            add_action_bias: bool = True,
            dropout: float = 0.0,
            num_linking_features: int = 0,
            num_entity_bits: int = 0,
            entity_bits_output: bool = True,
            use_entities: bool = False,
            denotation_only: bool = False,
            # Deprecated parameter to load older models
            entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
            entity_similarity_mode: str = "dot_product",
            rule_namespace: str = 'rule_labels') -> None:
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 1  # Hardcoded until we feed lf syntax into the model
        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = \
                TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
            self._entity_similarity_layer._module.weight.data += 1  # pylint: disable=protected-access
        elif self._entity_similarity_mode == "dot_product":
            pass
        else:
            raise ValueError("Invalid entity_similarity_mode: {}".format(
                self._entity_similarity_mode))

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            self._denotation_classifier = torch.nn.Linear(
                self._encoder_output_dim, self._num_denotation_cats)
            # The rest of init is not needed in denotation-only mode, where we never decode actions.
            return

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=action_embedding_dim)
        # We are tying the action embeddings used for input and output
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions,
                                            embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(
            torch.FloatTensor(self._encoder_output_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        self._decoder_step = LinkingTransitionFunction(
            encoder_output_dim=self._encoder_output_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            num_start_types=self._num_start_types,
            predict_start_type_separately=False,
            add_action_bias=self._add_action_bias,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout)
Ejemplo n.º 22
0
    def __init__(self,
                 question_embedder: TextFieldEmbedder,
                 input_memory_embedder: TextFieldEmbedder,
                 output_memory_embedder: TextFieldEmbedder,
                 question_encoder: Seq2SeqEncoder,
                 input_memory_encoder: Seq2VecEncoder,
                 output_memory_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 input_attention: Attention,
                 past_attention: Attention,
                 action_embedding_dim: int,
                 max_decoding_steps: int,
                 nhop: int,
                 decoding_nhop: int,
                 vocab: Vocabulary,
                 dataset_path: str = 'dataset',
                 parse_sql_on_decoding: bool = True,
                 training_beam_size: int = None,
                 add_action_bias: bool = True,
                 decoder_self_attend: bool = True,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels') -> None:
        super().__init__(vocab)

        self.question_embedder = question_embedder
        self._input_mm_embedder = input_memory_embedder
        self._output_mm_embedder = output_memory_embedder
        self._question_encoder = question_encoder
        self._input_mm_encoder = TimeDistributed(input_memory_encoder)
        self._output_mm_encoder = TimeDistributed(output_memory_encoder)

        self.parse_sql_on_decoding = parse_sql_on_decoding
        self._self_attend = decoder_self_attend
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._rule_namespace = rule_namespace
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._input_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        self._num_entity_types = 9
        self._entity_type_decoder_input_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)
        self._entity_type_decoder_output_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types,
            question_encoder.get_output_dim() // 2)

        self._decoder_num_layers = decoder_num_layers
        self._action_embedding_dim = action_embedding_dim

        self._ent2ent_ff = FeedForward(action_embedding_dim, 1,
                                       action_embedding_dim,
                                       Activation.by_name('relu')())

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(question_encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        if self._self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(
                encoder_output_dim=question_encoder.get_output_dim(),
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                past_attention=past_attention,
                decoding_nhop=decoding_nhop,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(
                encoder_output_dim=question_encoder.get_output_dim(),
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)

        self._mm_attn = MemAttn(question_encoder.get_output_dim(), nhop)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._acc_single = Average()
        self._acc_multi = Average()
        self._beam_hit = Average()

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(
            evaluate,
            db_dir=os.path.join(dataset_path, 'database'),
            table=os.path.join(dataset_path, 'tables.json'),
            check_valid=False)
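
The partial(...) call above pre-binds the dataset paths so the validation code can invoke the evaluator with just (gold, predicted, db_id). A generic illustration with a stand-in function (the real ``evaluate`` comes from the Spider evaluation script and has a different body):

from functools import partial

def evaluate(gold, predicted, db_id, db_dir, table, check_valid):
    # Stand-in: the real function executes and compares the two queries.
    return gold == predicted

evaluate_func = partial(evaluate,
                        db_dir='dataset/database',
                        table='dataset/tables.json',
                        check_valid=False)
assert evaluate_func('select 1', 'select 1', 'concert_singer')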
Ejemplo n.º 24
0
class AtisSemanticParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    input_attention: ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    database_file: ``str``, optional (default=/atis/atis.db)
        The path of the SQLite database when evaluating SQL queries. SQLite is disk based, so we need
        the file location to connect to it.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels',
                 database_file='/atis/atis.db') -> None:
        # Atis semantic parser init
        super().__init__(vocab)
        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        self._executor = SqlExecutor(database_file)
        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)


        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._num_entity_types = 2  # TODO(kevin): get this in a more principled way somehow?
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)
        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._transition_function = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                              action_embedding_dim=action_embedding_dim,
                                                              input_attention=input_attention,
                                                              predict_start_type_separately=False,
                                                              add_action_bias=self._add_action_bias,
                                                              dropout=dropout,
                                                              num_layers=self._decoder_num_layers)

    @overrides
    def forward(self,  # type: ignore
                utterance: Dict[str, torch.LongTensor],
                world: List[AtisWorld],
                actions: List[List[ProductionRule]],
                linking_scores: torch.Tensor,
                target_action_sequence: torch.LongTensor = None,
                sql_queries: List[List[str]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        utterance : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        world : ``List[AtisWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        linking_scores: ``torch.Tensor``
            A matrix of the linking the utterance tokens and the entities. This is a binary matrix that
            is deterministically generated where each entry indicates whether a token generated an entity.
            This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
        target_action_sequence : torch.Tensor, optional (default=None)
            The action sequence for the correct action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.
        sql_queries : List[List[str]], optional (default=None)
            A list of the SQL queries that are given during training or validation.
        """
        initial_state = self._get_initial_state(utterance, world, actions, linking_scores)
        batch_size = linking_scores.shape[0]
        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            # The MML trainer expects targets of shape (batch_size, num_sequences, sequence_length),
            # so we unsqueeze a num_sequences dimension of 1 below.
            return self._decoder_trainer.decode(initial_state,
                                                self._transition_function,
                                                (target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
        else:
            # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            outputs['linking_scores'] = linking_scores
            if target_action_sequence is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._transition_function,
                                                               (target_action_sequence.unsqueeze(1),
                                                                target_mask.unsqueeze(1)))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            outputs['utterance'] = []
            outputs['tokenized_utterance'] = []

            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[(i, action_index)]
                                  for action_index in best_action_indices]
                predicted_sql_query = action_sequence_to_sql(action_strings)

                if target_action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = target_action_sequence[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                if sql_queries and sql_queries[i]:
                    denotation_correct = self._executor.evaluate_sql_query(predicted_sql_query, sql_queries[i])
                    self._denotation_accuracy(denotation_correct)
                    outputs['sql_queries'].append(sql_queries[i])

                outputs['utterance'].append(world[i].utterances[-1])
                outputs['tokenized_utterance'].append([token.text
                                                       for token in world[i].tokenized_utterances[-1]])
                outputs['entities'].append(world[i].entities)
                outputs['best_action_sequence'].append(action_strings)
                outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
            return outputs

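A small sketch of the target handling at the top of forward(): squeeze the trailing dimension added by ListField[IndexField], build the padding mask, and unsqueeze a num_sequences dimension of 1 for the MML trainer (toy shapes):

import torch

action_padding_index = -1
# (batch_size=2, sequence_length=4, 1), as produced by ListField[IndexField]
target_action_sequence = torch.tensor([[[3], [7], [2], [-1]],
                                       [[5], [1], [-1], [-1]]])
target_action_sequence = target_action_sequence.squeeze(-1)   # (2, 4)
target_mask = target_action_sequence != action_padding_index  # (2, 4), bool
assert target_mask.tolist() == [[True, True, True, False],
                                [True, True, False, False]]
# The MML trainer consumes (batch_size, num_sequences, sequence_length):
assert target_action_sequence.unsqueeze(1).shape == (2, 1, 4)
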
    def _get_initial_state(self,
                           utterance: Dict[str, torch.LongTensor],
                           worlds: List[AtisWorld],
                           actions: List[List[ProductionRule]],
                           linking_scores: torch.Tensor) -> GrammarBasedState:
        embedded_utterance = self._utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size = embedded_utterance.size(0)
        num_entities = max([len(world.entities) for world in worlds])

        # entity_types: tensor with shape (batch_size, num_entities)
        entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance)

        # (batch_size, num_utterance_tokens, embedding_dim)
        encoder_input = embedded_utterance

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             utterance_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                initial_rnn_state.append(RnnStatelet(final_encoder_output[i].repeat(self._decoder_num_layers, 1),
                                                     memory_cell[i].repeat(self._decoder_num_layers, 1),
                                                     self._first_action_embedding,
                                                     self._first_attended_utterance,
                                                     encoder_output_list,
                                                     utterance_mask_list))
            else:
                initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                     memory_cell[i],
                                                     self._first_action_embedding,
                                                     self._first_attended_utterance,
                                                     encoder_output_list,
                                                     utterance_mask_list))


        initial_grammar_state = [self._create_grammar_state(worlds[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            entity_types[i])
                                 for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          debug_info=None)
        return initial_state

    @staticmethod
    def _get_type_vector(worlds: List[AtisWorld],
                         num_entities: int,
                         tensor: torch.Tensor = None) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the encoding for each entity's type. In addition, a map from a flattened entity
        index to type is returned to combine entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[AtisWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []

        for batch_index, world in enumerate(worlds):
            types = []
            entities = [('number', entity)
                        if any([entity.startswith(numeric_nonterminal)
                                for numeric_nonterminal in NUMERIC_NONTERMINALS])
                        else ('string', entity)
                        for entity in world.entities]

            for entity_index, entity in enumerate(entities):
                # We need numbers to be first, then strings, since our entities are going to be
                # sorted. We do a split by type and then a merge later, and it relies on this sorting.
                if entity[0] == 'number':
                    entity_type = 1
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)

        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence is a prefix of the (padded) target sequence.
        return predicted_tensor.equal(targets_trimmed)

    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track four metrics here:

            1. exact_match, which is the percentage of the time that our best output action sequence
            matches the SQL query exactly.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that can be parsed. (make sure
            you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. valid_sql_query, which is the percentage of time that decoding actually produces a
            valid SQL query.  We might not produce a valid SQL query if the decoder gets
            into a repetitive loop, or we're trying to produce a super long SQL query and run
            out of time steps, or something.

            4. action_similarity, which is how similar the action sequence predicted is to the actual
            action sequence. This is basically a soft measure of exact_match.
        """
        return {
                'exact_match': self._exact_match.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'valid_sql_query': self._valid_sql_query.get_metric(reset),
                'action_similarity': self._action_similarity.get_metric(reset)
                }

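All four numbers come from AllenNLP Average metrics, which accumulate per-example 0/1 scores and return their mean. A minimal illustration of the reset behaviour relied on above:

from allennlp.training.metrics import Average

exact_match = Average()
for correct in [1, 0, 1, 1]:
    exact_match(correct)          # accumulate one value per example
assert exact_match.get_metric(reset=True) == 0.75
# After reset=True the accumulator starts fresh for the next epoch.
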
    def _create_grammar_state(self,
                              world: AtisWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``AtisWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_utterance_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = world.valid_actions
        entity_map = {}
        entities = world.entities

        for entity_index, entity in enumerate(entities):
            entity_map[entity] = entity_index

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [action_map[action_string] for action_string in action_strings]
            production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append((production_rule_array[2], action_index))
                else:
                    linked_actions.append((production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0).to(entity_types.device).long()
                global_input_embeddings = self._action_embedder(global_action_tensor)
                global_output_embeddings = self._output_action_embedder(global_action_tensor)
                translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                           global_output_embeddings,
                                                           list(global_action_ids))
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = linked_rules
                entity_ids = [entity_map[entity] for entity in entities]
                entity_linking_scores = linking_scores[entity_ids]
                entity_type_tensor = entity_types[entity_ids]
                entity_type_embeddings = (self._entity_type_decoder_embedding(entity_type_tensor)
                                          .to(entity_types.device)
                                          .float())
                translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                           entity_type_embeddings,
                                                           list(linked_action_ids))

        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal)

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['utterance_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
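# --- Added illustration (not part of the original example) ---
# A minimal, self-contained sketch of the bookkeeping `decode` performs above:
# pair each considered action with its probability, drop the -1 padding entries,
# and sort so the output is deterministic.  `action_mapping` is a toy stand-in
# for the model's real (batch_index, action_index) -> production-rule mapping.
action_mapping = {(0, 0): 'statement -> [query, ";"]', (0, 1): 'query -> [select_core]'}
considered_actions = [1, 0, -1]
probabilities = [0.7, 0.2, 0.1]
actions = [(action_mapping[(0, action)], probability)
           for action, probability in zip(considered_actions, probabilities)
           if action != -1]
actions.sort()
considered_actions, action_probabilities = zip(*actions)
print(considered_actions)    # ('query -> [select_core]', 'statement -> [query, ";"]')
print(action_probabilities)  # (0.2, 0.7)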
Example #25
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 question_embedder: TextFieldEmbedder,
                 input_attention: Attention,
                 past_attention: Attention,
                 graph_attention: Attention,
                 max_decoding_steps: int,
                 action_embedding_dim: int,
                 enable_gating: bool = False,
                 ablation_mode: str = None,
                 gnn: bool = True,
                 graph_loss_lambda: float = 0.5,
                 decoder_use_graph_entities: bool = True,
                 decoder_self_attend: bool = True,
                 gnn_timesteps: int = 2,
                 pruning_gnn_timesteps: int = 2,
                 parse_sql_on_decoding: bool = True,
                 add_action_bias: bool = False,
                 use_neighbor_similarity_for_linking: bool = True,
                 dataset_path: str = 'dataset',
                 log_path: str = '',
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels') -> None:
        super().__init__(vocab, encoder, entity_encoder, question_embedder, gnn_timesteps, dropout, rule_namespace)

        self.enable_gating = enable_gating
        self.ablation_mode = ablation_mode
        self._log_path = log_path
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias

        self._parse_sql_on_decoding = parse_sql_on_decoding
        self._self_attend = decoder_self_attend
        self._decoder_use_graph_entities = decoder_use_graph_entities
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._beam_hit = Average()

        self._action_embedding_dim = action_embedding_dim

        self._graph_loss_lambda = graph_loss_lambda

        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        self._embedding_projector = torch.nn.Linear(question_embedder.get_output_dim(), self._embedding_dim, bias=False)
        self._bert_embedding_dim = question_embedder.get_output_dim()
        encoder_output_dim = self._encoder.get_output_dim() + self._embedding_dim

        self._neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))

        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder_output_dim))
        self._first_attended_output = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        torch.nn.init.normal_(self._first_attended_output)

        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        self._graph_pruning = GraphPruning(3, self._embedding_dim, encoder.get_output_dim(), dropout,
                                           timesteps=pruning_gnn_timesteps)

        if decoder_self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(encoder_output_dim=encoder_output_dim,
                                                                                action_embedding_dim=action_embedding_dim,
                                                                                input_attention=input_attention,
                                                                                past_attention=past_attention,
                                                                                enable_gating=self.enable_gating,
                                                                                ablation_mode=self.ablation_mode,
                                                                                predict_start_type_separately=False,
                                                                                add_action_bias=self._add_action_bias,
                                                                                dropout=dropout,
                                                                                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(encoder_output_dim=encoder_output_dim,
                                                                  action_embedding_dim=action_embedding_dim,
                                                                  input_attention=input_attention,
                                                                  predict_start_type_separately=False,
                                                                  add_action_bias=self._add_action_bias,
                                                                  dropout=dropout,
                                                                  num_layers=self._decoder_num_layers)

        if self.enable_gating:
            self._graph_attention = graph_attention
        else:
            self._graph_attention = DotProductAttention()

        self._embedding_sim_attn = CosineMatrixAttention()

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(evaluate,
                                      db_dir=os.path.join(dataset_path, 'database'),
                                      table=os.path.join(dataset_path, 'tables.json'),
                                      check_valid=False)
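# --- Added illustration (not part of the original example) ---
# How the learned "first step" vectors above are created: an uninitialized tensor
# is wrapped in a Parameter (so the optimizer updates it with the rest of the
# model) and then filled in place from a standard normal distribution.  The
# dimension 8 is a toy stand-in for `action_embedding_dim`.
import torch

first_action_embedding = torch.nn.Parameter(torch.FloatTensor(8))
torch.nn.init.normal_(first_action_embedding)
print(first_action_embedding.shape, first_action_embedding.requires_grad)  # torch.Size([8]) True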
Example #26
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 num_linking_features: int = 0,
                 num_entity_bits: int = 0,
                 entity_bits_output: bool = True,
                 use_entities: bool = False,
                 denotation_only: bool = False,
                 # Deprecated parameter to load older models
                 entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
                 entity_similarity_mode: str = "dot_product",
                 rule_namespace: str = 'rule_labels') -> None:
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 1  # Hardcoded until we feed lf syntax into the model
        self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = \
                TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
            self._entity_similarity_layer._module.weight.data += 1  # pylint: disable=protected-access
        elif self._entity_similarity_mode == "dot_product":
            pass
        else:
            raise ValueError("Invalid entity_similarity_mode: {}".format(self._entity_similarity_mode))

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            self._denotation_classifier = torch.nn.Linear(self._encoder_output_dim,
                                                          self._num_denotation_cats)
            # Rest of init not needed for denotation only where no decoding to actions needed
            return

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        # We are tying the action embeddings used for input and output
        # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(self._encoder_output_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder_output_dim,
                                                       action_embedding_dim=action_embedding_dim,
                                                       input_attention=attention,
                                                       num_start_types=self._num_start_types,
                                                       predict_start_type_separately=False,
                                                       add_action_bias=self._add_action_bias,
                                                       mixture_feedforward=mixture_feedforward,
                                                       dropout=dropout)
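# --- Added illustration (not part of the original example) ---
# What the weight tying above means in practice: both attributes reference the
# same module, so a gradient step through either one updates the single shared
# weight matrix.  `torch.nn.Embedding` stands in for AllenNLP's `Embedding`.
import torch

action_embedder = torch.nn.Embedding(num_embeddings=5, embedding_dim=3)
output_action_embedder = action_embedder  # tied weights, as in the model
print(output_action_embedder.weight is action_embedder.weight)  # True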
Example #27
class QuarelSemanticParser(Model):
    """
    A ``QuarelSemanticParser`` is a variant of ``WikiTablesSemanticParser`` with various
    tweaks and changes.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=0)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  If this is 0, no feature-based term is learned, and the linking scores
        come from embedding similarity alone.
    use_entities : ``bool``, optional (default=False)
        Whether dynamic entities are part of the action space
    num_entity_bits : ``int``, optional (default=0)
        Number of bits added to the encoder input/output to represent tagged entities (0 disables
        this).
    entity_bits_output : ``bool``, optional (default=True)
        If ``True``, entity bits are appended to the encoder output; otherwise to its input.
    denotation_only : ``bool``, optional (default=False)
        Whether to only predict the target denotation, skipping the whole logical form decoder.
    entity_similarity_mode : ``str``, optional (default="dot_product")
        How to compute vector similarity between question and entity tokens, can take values
        "dot_product" or "weighted_dot_product" (learned weights on each dimension)
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    """
    def __init__(
            self,
            vocab: Vocabulary,
            question_embedder: TextFieldEmbedder,
            action_embedding_dim: int,
            encoder: Seq2SeqEncoder,
            decoder_beam_search: BeamSearch,
            max_decoding_steps: int,
            attention: Attention,
            mixture_feedforward: FeedForward = None,
            add_action_bias: bool = True,
            dropout: float = 0.0,
            num_linking_features: int = 0,
            num_entity_bits: int = 0,
            entity_bits_output: bool = True,
            use_entities: bool = False,
            denotation_only: bool = False,
            # Deprecated parameter to load older models
            entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
            entity_similarity_mode: str = "dot_product",
            rule_namespace: str = 'rule_labels') -> None:
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 1  # Hardcoded until we feed lf syntax into the model
        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = \
                TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
            self._entity_similarity_layer._module.weight.data += 1  # pylint: disable=protected-access
        elif self._entity_similarity_mode == "dot_product":
            pass
        else:
            raise ValueError("Invalid entity_similarity_mode: {}".format(
                self._entity_similarity_mode))

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            self._denotation_classifier = torch.nn.Linear(
                self._encoder_output_dim, self._num_denotation_cats)
            # Rest of init not needed for denotation only where no decoding to actions needed
            return

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=action_embedding_dim)
        # We are tying the action embeddings used for input and output
        # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions,
                                            embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(
            torch.FloatTensor(self._encoder_output_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        self._decoder_step = LinkingTransitionFunction(
            encoder_output_dim=self._encoder_output_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            num_start_types=self._num_start_types,
            predict_start_type_separately=False,
            add_action_bias=self._add_action_bias,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout)

    @overrides
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[QuarelWorld],
            actions: List[List[ProductionRule]],
            entity_bits: torch.Tensor = None,
            denotation_target: torch.Tensor = None,
            target_action_sequences: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``,
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        entity_bits : ``torch.Tensor``, optional (default=None)
            Tensor encoding bits for the world entities.
        denotation_target : ``torch.Tensor``, optional (default=None)
            If model's field ``denotation_only`` is True, this is the tensor target denotation.
        target_action_sequences : ``torch.LongTensor``, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : List[Dict[str, Any]], optional (default=None).
            A dictionary of metadata for each batch element which has keys:
                question_tokens : ``List[str]``, optional.
                    The original string tokens in the question.
                world_extractions : ``nltk.Tree``, optional.
                    Extracted worlds from the question.
                answer_index : ``List[str]``, optional.
                    Index of the correct answer.
        """

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity.  Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (
                    embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (
                    embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(
                    embedded_table.view(batch_size,
                                        num_entities * num_entity_tokens,
                                        self._embedding_dim),
                    torch.transpose(embedded_question, 1, 2))
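                # (Added note) Because both operands were L2-normalized just above,
                # this batched matmul computes the cosine similarity between every
                # entity token and every question token in a single operation.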

                question_entity_similarity = question_entity_similarity.view(
                    batch_size, num_entities, num_entity_tokens,
                    num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(
                    question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (
                    embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (
                    embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(
                    -1, num_entities * num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size,
                                          num_entities * num_entity_tokens,
                                          self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(
                    batch_size,
                    num_question_tokens * num_entities * num_entity_tokens,
                    self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(
                    product)
                question_entity_similarity = question_entity_similarity.view(
                    batch_size, num_entities, num_entity_tokens,
                    num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(
                    question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(
                    linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(
                world, linking_scores.transpose(1, 2), question_mask,
                entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Fake linking_scores, added so that downstream code doesn't complain
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(
                final_encoder_output)
            loss = torch.nn.functional.cross_entropy(
                denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_question, encoder_output_list,
                            question_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=None,
            debug_info=None)

        if self.training:
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
            return outputs

        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(
                    initial_state, self._decoder_step,
                    (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            # `feature_scores` and `linking_features` are only computed when entities are in
            # use, so guard on both conditions to avoid a NameError.
            if self._use_entities and self._linking_params is not None:
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs['question_tokens'].append(metadata[i].get(
                        'question_tokens', []))
                    outputs['world_extractions'].append(metadata[i].get(
                        'world_extractions', {}))
                outputs['entities'].append(world[i].table_graph.entities)
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][
                        0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(
                            best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in best_action_indices
                    ]
                    try:
                        logical_form = world[i].get_logical_form(
                            action_strings, add_var_function=False)
                        # Record success only after parsing succeeds; otherwise a
                        # failure would log both 1.0 and 0.0 into the metric.
                        self._has_logical_form(1.0)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(
                            predicted_answer_index, answer_index)
                        self._denotation_accuracy(denotation_accuracy)
                    score = math.exp(
                        best_final_states[i][0].score[0].data.cpu().item())
                    outputs['answer_index'].append(predicted_answer_index)
                    outputs['score'].append(score)
                    outputs['parse_acc'].append(sequence_in_targets)
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['denotation_acc'].append(denotation_accuracy)
                    outputs['debug_info'].append(
                        best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    outputs['parse_acc'].append(0)
                    outputs['logical_form'].append('')
                    outputs['denotation_acc'].append(0)
                    outputs['score'].append(0)
                    outputs['answer_index'].append(-1)
                    outputs['best_action_sequence'].append([])
                    outputs['debug_info'].append([])
                    self._has_logical_form(0.0)
            return outputs

    @staticmethod
    def _get_type_vector(
            worlds: List[QuarelWorld], num_entities: int,
            tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's
        type. In addition, a map from a flattened entity index to type is returned to combine
        entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)
        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types

    def _get_linking_probabilities(
            self, worlds: List[QuarelWorld], linking_scores: torch.FloatTensor,
            question_mask: torch.LongTensor,
            entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities +
                                        entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices,
                                                    dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(
                    1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores,
                                                                 dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(
                    num_question_tokens,
                    num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(
            torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()
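        # Worked example (added): with predicted = [3, 7] and
        # targets = [[3, 7, 0], [3, 9, 0]], targets_trimmed is [[3, 7], [3, 9]];
        # the row-wise min of the equality mask is [1, 0], and the max over rows
        # is 1 because at least one target starts with the predicted sequence.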

    def _denotation_match(self, predicted_answer_index: int,
                          target_answer_index: int) -> float:
        if predicted_answer_index < 0:
            # Logical form doesn't properly resolve, so we give the expected credit of a random guess
            return 1.0 / self._num_denotation_cats
        elif predicted_answer_index == target_answer_index:
            return 1.0
        return 0.0

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. parse_acc, which is the percentage of the time that our best output action sequence
            corresponds to a correct logical form

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation, including spurious correct answers using the wrong logical form

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        if self._denotation_only:
            metrics = {
                'denotation_acc':
                self._denotation_accuracy_cat.get_metric(reset)
            }
        else:
            metrics = {
                'parse_acc': self._action_sequence_accuracy.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'lf_percent': self._has_logical_form.get_metric(reset),
            }
        return metrics

    def _create_grammar_state(self, world: QuarelWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``QuarelWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions.  Guarding on
            # `global_actions` avoids a ValueError from `zip(*[])` when a non-terminal has
            # only linked productions.
            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0)
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                if self._add_action_bias:
                    global_action_biases = self._action_biases(
                        global_action_tensor)
                    global_input_embeddings = torch.cat(
                        [global_input_embeddings, global_action_biases],
                        dim=-1)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(
                    entity_type_tensor)
                translated_valid_actions[key]['linked'] = (
                    entity_linking_scores, entity_type_embeddings,
                    list(linked_action_ids))

        return GrammarStatelet([START_SYMBOL], translated_valid_actions,
                               type_declaration.is_nonterminal)

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``FrictionQDecoderStep``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(
                    predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions,
                                               probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)],
                                        probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info.get(
                    'question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
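
# A minimal, self-contained sketch (toy inputs, not part of the model) of the
# sort-and-zip pattern ``decode`` uses above: pair each considered action with its
# probability, drop the -1 padding entries, and re-sort alphabetically.
action_mapping = {(0, 0): 'q -> select', (0, 1): 's -> [q]'}
considered_actions, probabilities = [1, 0, -1], [0.7, 0.2, 0.1]
actions = [(action_mapping[(0, action)], probability)
           for action, probability in zip(considered_actions, probabilities)
           if action != -1]
actions.sort()
considered_actions, probabilities = zip(*actions)
# considered_actions == ('q -> select', 's -> [q]'); probabilities == (0.2, 0.7)
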
Example #28
class LCQuADMmlSemanticParser(LCQuADSemanticParser):
    """
    ``LCQuADMmlSemanticParser`` is an ``LCQuADSemanticParser`` that copes with the lack of
    logical form annotations by maximizing the marginal likelihood of an approximate set of
    target sequences that yield the correct denotation. It takes the output of an offline
    search process as the set of target sequences for training, rather than performing the
    search during training.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the TransitionFunction.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
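    val_outputs : ``file``
        An open, writable file-like object; at validation time, gold and generated
        logical forms are written to it as JSON lines (see ``_update_seq_metrics``).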
    max_decoding_steps : ``int``
        Maximum number of steps for beam search after training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    training_beam_size : ``int``, optional (default=None)
        If given, we will use a constrained beam search of this size during training, so that we
        use only the top ``training_beam_size`` action sequences according to the model in the MML
        computation.  If this is ``None``, we will use all of the provided action sequences in the
        MML computation.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 decoder_beam_search: BeamSearch,
                 val_outputs,
                 max_decoding_steps: int,
                 training_beam_size: int = None,
                 dropout: float = 0.0) -> None:

        super().__init__(vocab=vocab,
                         sentence_embedder=sentence_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         dropout=dropout)
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._decoder_step = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            activation=Activation.by_name('relu')(),
            add_action_bias=False,
            dropout=dropout)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        self._action_padding_index = -1
        self.val_outputs = val_outputs

    @overrides
    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            question_predicates,
            # labelled_results,
            world: List[LCQuADLanguage],
            actions: List[List[ProductionRule]],
            question_entities=None,
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None,
            logical_forms=None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        assert target_action_sequences is not None

        batch_size = question['tokens'].size()[0]
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        # assert target_action_sequences.dim() == 3
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index

        # if self._kg_embedder:
        #     embedded_entities = self._kg_embedder(question_entities, input_type="entity")
        #     embedded_type_entities = self._kg_embedder(question_type_entities, input_type="entity")
        #     embedded_predicates = self._kg_embedder(question_predicates, input_type="predicate")

        initial_rnn_state = self._get_initial_rnn_state(question)
        initial_score_list = [
            next(iter(question.values())).new_zeros(1, dtype=torch.float)
            for _ in range(batch_size)
        ]

        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_statelet = [
            self._create_grammar_state(world[i], actions[i])
            for i in range(batch_size)
        ]
        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_statelet,
            possible_actions=actions)

        outputs = self._decoder_trainer.decode(
            initial_state, self._decoder_step,
            (target_action_sequences, target_mask))
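        # ``outputs['loss']`` is the MML loss: the negative log of the total probability
        # mass the model assigns to the provided target sequences.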

        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._decoder_beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            best_action_sequences: Dict[int, List[List[int]]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = [
                        best_final_states[i][0].action_history[0]
                    ]
                    best_action_sequences[i] = best_action_indices
            batch_action_strings = self._get_action_strings(
                actions, best_action_sequences)

            # self._update_metrics(action_strings=batch_action_strings,
            #                     worlds=world,
            #                     labelled_results=labelled_results)

            debug_infos = []
            for i in range(batch_size):
                debug_infos.append(best_final_states[i][0].debug_info[0])

            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]

            self._update_seq_metrics(action_strings=batch_action_strings,
                                     worlds=world,
                                     gold_logical_forms=logical_forms,
                                     train=self.training)

            outputs["predicted queries"] = batch_action_strings

            best_actions = batch_action_strings
            batch_action_info = []
            for batch_index, (predicted_actions, debug_info) in enumerate(
                    zip(best_actions, debug_infos)):
                instance_action_info = []
                for predicted_action, action_debug_info in zip(
                        predicted_actions[0], debug_info):
                    action_info = {}
                    action_info['predicted_action'] = predicted_action
                    considered_actions = action_debug_info[
                        'considered_actions']
                    probabilities = action_debug_info['probabilities']
                    # Use a fresh name so we don't clobber the ``actions`` argument.
                    scored_actions = []
                    for action, probability in zip(considered_actions,
                                                   probabilities):
                        if action != -1:
                            scored_actions.append(
                                (action_mapping[(batch_index, action)],
                                 probability))
                    scored_actions.sort()
                    considered_actions, probabilities = zip(*scored_actions)
                    action_info['considered_actions'] = considered_actions
                    action_info['action_probabilities'] = probabilities
                    action_info['question_attention'] = action_debug_info.get(
                        'question_attention', [])
                    instance_action_info.append(action_info)
                batch_action_info.append(instance_action_info)
            outputs["predicted_actions"] = batch_action_info

        return outputs

    def _update_seq_metrics(self,
                            action_strings: List[List[List[str]]],
                            worlds: List[LCQuADLanguage],
                            gold_logical_forms: List[str],
                            train=False):
        batch_size = len(worlds)
        for i in range(batch_size):
            generated_lf = worlds[i].action_sequence_to_logical_form(
                action_strings[i][0]) if action_strings else ''
            gold_lf = gold_logical_forms[i][0]
            if not train:
                self.val_outputs.write(
                    json.dumps({
                        "gold": gold_lf,
                        "generated": generated_lf
                    }))
                self.val_outputs.write("\n")
                self.val_outputs.write("\n")
                self.val_outputs.flush()

            self._compute_instance_seq_metric(generated_lf, gold_lf)

        for metric_name, metric in self._metrics.items():
            overall_accuracy_metric = self._metrics[OVERALL_ACC_SCORE]
            if metric_name == "accuracy":
                accuracy = metric.get_metric()
                overall_accuracy_metric(accuracy)

    # noinspection PyTypeChecker
    def _update_metrics(
            self, action_strings: List[List[List[str]]],
            worlds: List[LCQuADLanguage],
            labelled_results: List[Union[bool, int, Set[Entity]]]) -> None:

        batch_size = len(worlds)

        def retrieve_results(action_strings, world):
            action_sequence = action_strings[0] if action_strings else []
            return world.execute_action_sequence(action_sequence)

        # retrieved_results = Parallel(n_jobs=10)(delayed(retrieve_results)(actions, world) for actions, world
        #                                                 in zip(action_strings[:batch_size], worlds[:batch_size]))
        # retrieved_results = list(map(retrieve_results, action_strings[:batch_size], worlds[:batch_size]))

        for i in range(batch_size):
            self._compute_instance_metrics(
                action_strings[i][0] if action_strings else [],
                labelled_results[i], worlds[i])

        # Update overall score.
        for metric_name, metric in self._metrics.items():
            overall_score_metric = self._metrics[OVERALL_SCORE]
            overall_accuracy_metric = self._metrics[OVERALL_ACC_SCORE]
            if metric_name == "accuracy":
                accuracy = metric.get_metric()
                # if metric_name.replace("accuracy", "") not in self.retrieval_question_types:
                #     overall_score_metric(accuracy)
                overall_accuracy_metric(accuracy)

            elif metric_name == "precision":
                precision = metric.get_metric()
                recall = self._metrics["recall"].get_metric()
                f1 = 2 * (precision * recall) / (precision + recall + 1e-9)
                overall_score_metric(f1)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = OrderedDict()
        for key, metric in self._metrics.items():
            # Always write something to the dict to preserve metric ordering.
            metrics[key] = metric.get_metric(reset)
        return metrics
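
# Illustrative sketch (toy numbers, not the AllenNLP implementation): the
# ``MaximumMarginalLikelihood`` trainer used above maximizes log sum_j p(y_j | x)
# over the approximate target sequences, so a per-instance loss is the negative
# log-sum-exp of the sequence log-probabilities.
import torch

sequence_log_probs = torch.tensor([-1.2, -2.3, -4.0])  # toy values of log p(y_j | x)
mml_loss = -torch.logsumexp(sequence_log_probs, dim=0)  # ~0.87
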
Example #29
class SpiderParser(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 question_embedder: TextFieldEmbedder,
                 input_attention: Attention,
                 past_attention: Attention,
                 max_decoding_steps: int,
                 action_embedding_dim: int,
                 gnn: bool = True,
                 decoder_use_graph_entities: bool = True,
                 decoder_self_attend: bool = True,
                 gnn_timesteps: int = 2,
                 parse_sql_on_decoding: bool = True,
                 add_action_bias: bool = True,
                 use_neighbor_similarity_for_linking: bool = True,
                 dataset_path: str = 'dataset',
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels',
                 scoring_dev_params: dict = None,
                 debug_parsing: bool = False) -> None:
        super().__init__(vocab)
        self.vocab = vocab
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._question_embedder = question_embedder
        self._add_action_bias = add_action_bias
        self._scoring_dev_params = scoring_dev_params or {}
        self.parse_sql_on_decoding = parse_sql_on_decoding
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        self._self_attend = decoder_self_attend
        self._decoder_use_graph_entities = decoder_use_graph_entities

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._acc_single = Average()
        self._acc_multi = Average()
        self._beam_hit = Average()

        self._action_embedding_dim = action_embedding_dim

        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        encoder_output_dim = encoder.get_output_dim()
        if gnn:
            encoder_output_dim += action_embedding_dim

        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder_output_dim))
        self._first_attended_output = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        torch.nn.init.normal_(self._first_attended_output)

        self._num_entity_types = 9
        self._embedding_dim = question_embedder.get_output_dim()

        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._linking_params = torch.nn.Linear(16, 1)
        torch.nn.init.uniform_(self._linking_params.weight, 0, 1)
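        # ``self._linking_params`` scores the 16-dimensional linking feature vector
        # attached to each (entity, question token) pair (see ``schema['linking']``).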

        num_edge_types = 3
        self._gnn = GatedGraphConv(self._embedding_dim,
                                   gnn_timesteps,
                                   num_edge_types=num_edge_types,
                                   dropout=dropout)

        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        if decoder_self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                past_attention=past_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)

        self._ent2ent_ff = FeedForward(action_embedding_dim, 1,
                                       action_embedding_dim,
                                       Activation.by_name('relu')())

        self._neighbor_params = torch.nn.Linear(self._embedding_dim,
                                                self._embedding_dim)

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(
            evaluate,
            db_dir=os.path.join(dataset_path, 'database'),
            table=os.path.join(dataset_path, 'tables.json'),
            check_valid=False)

        self.debug_parsing = debug_parsing

    @overrides
    def forward(
            self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            world: List[SpiderWorld],
            schema: Dict[str, torch.LongTensor],
            action_sequence: torch.LongTensor = None
    ) -> Dict[str, torch.Tensor]:

        batch_size = len(world)
        device = utterance['tokens'].device

        initial_state = self._get_initial_state(utterance, world, schema,
                                                valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            action_mask = action_sequence != self._action_padding_index
        else:
            action_mask = None

        if self.training:
            decode_output = self._decoder_trainer.decode(
                initial_state, self._transition_function,
                (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))

            return {'loss': decode_output['loss']}
        else:
            loss = torch.tensor([0]).float().to(device)
            if action_sequence is not None and action_sequence.size(1) > 1:
                try:
                    loss = self._decoder_trainer.decode(
                        initial_state, self._transition_function,
                        (action_sequence.unsqueeze(1),
                         action_mask.unsqueeze(1)))['loss']
                except ZeroDivisionError:
                    # reached a dead-end during beam search
                    pass

            outputs: Dict[str, Any] = {'loss': loss}

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]

            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False)

            self._compute_validation_outputs(valid_actions, best_final_states,
                                             world, action_sequence, outputs)
            return outputs

    def _get_initial_state(
            self, utterance: Dict[str, torch.LongTensor],
            worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor],
            actions: List[List[ProductionRule]]) -> GrammarBasedState:
        schema_text = schema['text']
        """KAIMARY"""
        # TextFieldEmbedder needs a "token" key in the Dict
        """
        embedded_schema:torch.Size([batch_size, num_entities, max_num_entity_tokens, embedding_dim])
        schema_mask:torch.Size([batch_size, num_entities, max_num_entity_tokens])
        embedded_utterance:torch.Size([batch_size, max_utterance_size, embedding_dim])
        entity_type_embeddings:torch.Size([batch_size, num_entities, embedding_dim])
        """
        embedded_schema = self._question_embedder(schema_text,
                                                  num_wrapping_dims=1)
        schema_mask = util.get_text_field_mask(schema_text,
                                               num_wrapping_dims=1).float()

        embedded_utterance = self._question_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
        num_entities = max([
            len(world.db_context.knowledge_graph.entities) for world in worlds
        ])
        num_question_tokens = utterance['tokens'].size(1)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            worlds, num_entities, embedded_schema.device)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_schema.view(batch_size, num_entities * num_entity_tokens,
                                 self._embedding_dim),
            torch.transpose(embedded_utterance, 1, 2))
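        # question_entity_similarity: (batch_size, num_entities * num_entity_tokens, num_question_tokens)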

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)
        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)
        """KAIMARY"""
        # Variable: linking_scores
        # The entitiy linking score s(e, i) in the Krishnamurthy 2017
        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = schema['linking']

        linking_scores = question_entity_similarity_max_score

        feature_scores = self._linking_params(linking_features).squeeze(3)
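        # feature_scores: (batch_size, num_entities, num_question_tokens)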

        linking_scores = linking_scores + feature_scores
        """KAIMARY"""
        # linking_probabilities
        # The scores s(e,i) are then fed into a softmax layer over all entities e of the same type
        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            worlds, linking_scores.transpose(1, 2), utterance_mask,
            entity_type_dict)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(worlds, num_entities,
                                                      linking_scores.device)

        if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
            """KAIMARY"""
            # Seq2VecEncoder get the hidden state of the last step as the unique output
            # (batch_size, num_entities, embedding_dim)
            encoded_table = self._entity_encoder(embedded_schema, schema_mask)

            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(
                encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask(
                {
                    'ignored': neighbor_indices + 1
                }, num_wrapping_dims=1).float()

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(
                BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                                  neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(
                embedded_neighbors.float())
            """KAIMARY"""
            # Variable: entity_embedding
            # Rv in B Bogin 2019
            # Is a learned embedding for the schema item v, which base the embedding on the type of v and its schema neighbors only
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings +
                                           projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)
        """KAIMARY"""
        # Variable: link_embedding
        # Li in B Bogin 2019
        # Is an average of entity vectors weighted by the resulting distribution
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        """KAIMARY"""
        # Variable: encoder_input
        # [Wi, Li] in B Bogin 2019
        encoder_input = torch.cat([link_embedding, embedded_utterance], 2)

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))
        """KAIMARY"""
        # Variable: max_entities_relevance
        # ρv = maxi plink(v | xi) in B Bogin 2019
        # Is the maximum probability of v for any word xi
        max_entities_relevance = linking_probabilities.max(dim=1)[0]
        entities_relevance = max_entities_relevance.unsqueeze(-1).detach()
        """KAIMARY"""
        # entity_type_embeddings ???
        # Variable: graph_initial_embedding
        # hv(0) in B Bogin 2019
        # Is an initial embedding conditioned on the relevance score, and then used to be fed into GNN
        graph_initial_embedding = entity_type_embeddings * entities_relevance

        encoder_output_dim = self._encoder.get_output_dim()
        if self._gnn:
            """KAIMARY"""
            # Variable: entities_graph_encoding
            # φv in  B Bogin 2019
            # Is the final representation of each schema item after L steps
            entities_graph_encoding = self._get_schema_graph_encoding(
                worlds, graph_initial_embedding)
            """KAIMARY"""
            # Variable: graph_link_embedding
            # Lφ,i in B Bogin 2019
            graph_link_embedding = util.weighted_sum(entities_graph_encoding,
                                                     linking_probabilities)
            encoder_outputs = torch.cat(
                (encoder_outputs, graph_link_embedding), dim=-1)
            encoder_output_dim = (self._action_embedding_dim +
                                  self._encoder.get_output_dim())
        else:
            entities_graph_encoding = None

        if self._self_attend:
            # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
            entities_ff = self._ent2ent_ff(entities_graph_encoding)
            linked_actions_linking_scores = torch.bmm(
                entities_ff, entities_ff.transpose(1, 2))
        else:
            linked_actions_linking_scores = [None] * batch_size

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(
                worlds[i], actions[i], linking_scores[i],
                linked_actions_linking_scores[i], entity_types[i],
                entities_graph_encoding[i]
                if entities_graph_encoding is not None else None)
            for i in range(batch_size)
        ]

        initial_sql_state = [
            SqlState(actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=actions,
            action_entity_mapping=[
                w.get_action_entity_mapping() for w in worlds
            ])

        return initial_state

    @staticmethod
    def _get_neighbor_indices(worlds: List[SpiderWorld], num_entities: int,
                              device: torch.device) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors. A tensor
        is accepted as a parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index. If all the entities in the batch
        have no neighbors, None will be returned.
        """

        num_neighbors = 0
        for world in worlds:
            for entity in world.db_context.knowledge_graph.entities:
                if len(world.db_context.knowledge_graph.neighbors[entity]
                       ) > num_neighbors:
                    num_neighbors = len(
                        world.db_context.knowledge_graph.neighbors[entity])

        batch_neighbors = []
        no_entities_have_neighbors = True
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.db_context.knowledge_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.db_context.knowledge_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [
                    entity2index[n] for n in entity2neighbors[entity]
                ]
                if entity_neighbors:
                    no_entities_have_neighbors = False
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors,
                                                num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(
                neighbor_indexes, num_entities, lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        # It is possible that none of the entities has any neighbors, since our definition of the
        # knowledge graph allows it when no entities or numbers were extracted from the question.
        if no_entities_have_neighbors:
            return None
        return torch.tensor(batch_neighbors, device=device, dtype=torch.long)

    def _get_schema_graph_encoding(
            self, worlds: List[SpiderWorld],
            initial_graph_embeddings: torch.Tensor) -> torch.Tensor:
        max_num_entities = max([
            len(world.db_context.knowledge_graph.entities) for world in worlds
        ])
        batch_size = initial_graph_embeddings.size(0)

        graph_data_list = []

        for batch_index, world in enumerate(worlds):
            x = initial_graph_embeddings[batch_index]

            adj_list = self._get_graph_adj_lists(
                initial_graph_embeddings.device, world,
                initial_graph_embeddings.size(1) - 1)
            graph_data = Data(x)
            for i, l in enumerate(adj_list):
                graph_data[f'edge_index_{i}'] = l
            graph_data_list.append(graph_data)

        batch = Batch.from_data_list(graph_data_list)

        gnn_output = self._gnn(batch.x, [
            batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types)
        ])

        num_nodes = max_num_entities
        gnn_output = gnn_output.view(batch_size, num_nodes, -1)
        # entities_encodings = gnn_output
        entities_encodings = gnn_output[:, :max_num_entities]
        # global_node_encodings = gnn_output[:, max_num_entities]

        return entities_encodings

    @staticmethod
    def _get_graph_adj_lists(device,
                             world,
                             global_entity_id,
                             global_node=False):
        entity_mapping = {}
        for i, entity in enumerate(world.db_context.knowledge_graph.entities):
            entity_mapping[entity] = i
        entity_mapping['_global_'] = global_entity_id
        adj_list_own = []  # column--table
        adj_list_link = []  # table->table / foreign->primary
        adj_list_linked = []  # table<-table / foreign<-primary
        adj_list_global = []  # node->global

        # TODO: Prepare in advance?
        for key, neighbors in world.db_context.knowledge_graph.neighbors.items():
            idx_source = entity_mapping[key]
            for n_key in neighbors:
                idx_target = entity_mapping[n_key]
                if n_key.startswith("table") or key.startswith("table"):
                    adj_list_own.append((idx_source, idx_target))
                elif n_key.startswith("string") or key.startswith("string"):
                    adj_list_own.append((idx_source, idx_target))
                elif key.startswith("column:foreign"):
                    adj_list_link.append((idx_source, idx_target))
                    src_table_key = f"table:{key.split(':')[2]}"
                    tgt_table_key = f"table:{n_key.split(':')[2]}"
                    idx_source_table = entity_mapping[src_table_key]
                    idx_target_table = entity_mapping[tgt_table_key]
                    adj_list_link.append((idx_source_table, idx_target_table))
                elif n_key.startswith("column:foreign"):
                    adj_list_linked.append((idx_source, idx_target))
                    src_table_key = f"table:{key.split(':')[2]}"
                    tgt_table_key = f"table:{n_key.split(':')[2]}"
                    idx_source_table = entity_mapping[src_table_key]
                    idx_target_table = entity_mapping[tgt_table_key]
                    adj_list_linked.append(
                        (idx_source_table, idx_target_table))
                else:
                    assert False

            adj_list_global.append((idx_source, entity_mapping['_global_']))

        all_adj_types = [adj_list_own, adj_list_link, adj_list_linked]

        if global_node:
            all_adj_types.append(adj_list_global)

        return [
            torch.tensor(l, device=device, dtype=torch.long).transpose(0, 1)
            if l else torch.tensor(l, device=device, dtype=torch.long)
            for l in all_adj_types
        ]

    def _create_grammar_state(
            self, world: SpiderWorld, possible_actions: List[ProductionRule],
            linking_scores: torch.Tensor,
            linked_actions_linking_scores: torch.Tensor,
            entity_types: torch.Tensor,
            entity_graph_encoding: torch.Tensor) -> GrammarStatelet:
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = world.valid_actions
        entity_map = {}
        entities = world.entities_names

        for entity_index, entity in enumerate(entities):
            entity_map[entity] = entity_index

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(
                    global_action_tensors,
                    dim=0).to(global_action_tensors[0].device).long()
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [
                    rule.split(' -> ')[1].strip('[]\"')
                    for rule in linked_rules
                ]

                entity_ids = [entity_map[entity] for entity in entities]

                entity_linking_scores = linking_scores[entity_ids]

                if linked_actions_linking_scores is not None:
                    entity_action_linking_scores = linked_actions_linking_scores[
                        entity_ids]

                if not self._decoder_use_graph_entities:
                    entity_type_tensor = entity_types[entity_ids]
                    entity_type_embeddings = (
                        self._entity_type_decoder_embedding(
                            entity_type_tensor).to(
                                entity_types.device).float())
                else:
                    entity_type_embeddings = entity_graph_encoding.index_select(
                        dim=0,
                        index=torch.tensor(
                            entity_ids, device=entity_graph_encoding.device))

                if self._self_attend:
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_embeddings,
                        list(linked_action_ids), entity_action_linking_scores)
                else:
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_embeddings,
                        list(linked_action_ids))

        return GrammarStatelet(['statement'], translated_valid_actions,
                               self.is_nonterminal)

    @staticmethod
    def is_nonterminal(token: str):
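        # Quoted tokens are terminals in this grammar; everything else is a non-terminal.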
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    def _get_linking_probabilities(
            self, worlds: List[SpiderWorld], linking_scores: torch.FloatTensor,
            question_mask: torch.LongTensor,
            entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "date_column:", followed by "number_column:", "string:", and "string_column:".
            # This is not a great assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.db_context.knowledge_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities +
                                        entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices,
                                                    dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(
                    1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores,
                                                                 dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(
                    num_question_tokens,
                    num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(
            torch.min(targets_trimmed.eq(predicted_tensor), dim=0)[0]).item()

    @staticmethod
    def _query_difficulty(targets: torch.LongTensor, action_mapping,
                          batch_index):
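        # A gold query counts as "hard" (multi-table) if more than one
        # ``table_name`` action appears in the target sequence.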
        number_tables = len([
            action_mapping[(batch_index, int(a))] for a in targets
            if a >= 0 and action_mapping[(batch_index,
                                          int(a))].startswith('table_name')
        ])
        return number_tables > 1

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            '_match/exact_match': self._exact_match.get_metric(reset),
            'sql_match': self._sql_evaluator_match.get_metric(reset),
            '_others/action_similarity':
            self._action_similarity.get_metric(reset),
            '_match/match_single': self._acc_single.get_metric(reset),
            '_match/match_hard': self._acc_multi.get_metric(reset),
            'beam_hit': self._beam_hit.get_metric(reset)
        }

    @staticmethod
    def _get_type_vector(worlds: List[SpiderWorld], num_entities: int,
                         device) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the encoding for each entity's type. In addition, a map from a flattened entity
        index to type is returned to combine entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        num_entities : ``int``
        device : ``torch.device``
            Used for placing the constructed tensor on the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``, where each entry is
        the entity's type id.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []

        column_type_ids = [
            'boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time'
        ]

        for batch_index, world in enumerate(worlds):
            types = []

            for entity_index, entity in enumerate(
                    world.db_context.knowledge_graph.entities):
                parts = entity.split(':')
                entity_main_type = parts[0]
                if entity_main_type == 'column':
                    column_type = parts[1]
                    entity_type = column_type_ids.index(column_type)
                elif entity_main_type == 'string':
                    # cell value
                    entity_type = len(column_type_ids)
                elif entity_main_type == 'table':
                    entity_type = len(column_type_ids) + 1
                else:
                    raise Exception("Unknown entity")
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)

        return torch.tensor(batch_types, dtype=torch.long,
                            device=device), entity_types

    def _compute_validation_outputs(self,
                                    actions: List[List[ProductionRuleArray]],
                                    best_final_states: Mapping[
                                        int, Sequence[GrammarBasedState]],
                                    world: List[SpiderWorld],
                                    target_list: torch.LongTensor,
                                    outputs: Dict[str, Any]) -> None:
        batch_size = len(actions)

        outputs['predicted_sql_query'] = []

        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]

        for i in range(batch_size):
            # gold sql exactly as given
            original_gold_sql_query = ' '.join(
                world[i].get_query_without_table_hints())

            if i not in best_final_states:
                self._exact_match(0)
                self._action_similarity(0)
                self._sql_evaluator_match(0)
                self._acc_multi(0)
                self._acc_single(0)
                outputs['predicted_sql_query'].append('')
                continue

            best_action_indices = best_final_states[i][0].action_history[0]

            action_strings = [
                action_mapping[(i, action_index)]
                for action_index in best_action_indices
            ]
            predicted_sql_query = action_sequence_to_sql(action_strings,
                                                         add_table_names=True)
            outputs['predicted_sql_query'].append(
                sqlparse.format(predicted_sql_query, reindent=False))

            if target_list is not None:
                targets = target_list[i].data

                sequence_in_targets = self._action_history_match(
                    best_action_indices, targets)
                self._exact_match(sequence_in_targets)

                sql_evaluator_match = self._evaluate_func(
                    original_gold_sql_query, predicted_sql_query,
                    world[i].db_id)
                self._sql_evaluator_match(sql_evaluator_match)

                similarity = difflib.SequenceMatcher(None, best_action_indices,
                                                     targets)
                self._action_similarity(similarity.ratio())

                difficulty = self._query_difficulty(targets, action_mapping, i)
                if difficulty:
                    self._acc_multi(sql_evaluator_match)
                else:
                    self._acc_single(sql_evaluator_match)

            beam_hit = False
            for pos, final_state in enumerate(best_final_states[i]):
                action_indices = final_state.action_history[0]
                action_strings = [
                    action_mapping[(i, action_index)]
                    for action_index in action_indices
                ]
                candidate_sql_query = action_sequence_to_sql(
                    action_strings, add_table_names=True)

                if target_list is not None:
                    correct = self._evaluate_func(original_gold_sql_query,
                                                  candidate_sql_query,
                                                  world[i].db_id)
                    if correct:
                        beam_hit = True
                    self._beam_hit(beam_hit)
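    # A minimal sketch of the action-similarity metric computed above, with
    # illustrative action-index sequences (not taken from the model):
    #
    #     import difflib
    #     predicted = [3, 7, 7, 2]
    #     gold = [3, 7, 2]
    #     difflib.SequenceMatcher(None, predicted, gold).ratio()
    #     # => 2 * matches / total_length = (2 * 3) / (4 + 3) ~= 0.857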
Example #30
class SpiderParser(SpiderBase):
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 question_embedder: TextFieldEmbedder,
                 input_attention: Attention,
                 past_attention: Attention,
                 max_decoding_steps: int,
                 action_embedding_dim: int,
                 gnn: bool = True,
                 graph_loss_lambda: float = 0.5,
                 decoder_use_graph_entities: bool = True,
                 decoder_self_attend: bool = True,
                 gnn_timesteps: int = 2,
                 pruning_gnn_timesteps: int = 2,
                 parse_sql_on_decoding: bool = True,
                 add_action_bias: bool = True,
                 use_neighbor_similarity_for_linking: bool = True,
                 dataset_path: str = 'dataset',
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels') -> None:
        super().__init__(vocab, encoder, entity_encoder, question_embedder,
                         gnn_timesteps, dropout, rule_namespace)

        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias

        self._parse_sql_on_decoding = parse_sql_on_decoding
        self._self_attend = decoder_self_attend
        self._decoder_use_graph_entities = decoder_use_graph_entities
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._beam_hit = Average()

        self._action_embedding_dim = action_embedding_dim

        self._graph_loss_lambda = graph_loss_lambda

        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        encoder_output_dim = encoder.get_output_dim()
        if gnn:
            encoder_output_dim += action_embedding_dim

        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder_output_dim))
        self._first_attended_output = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        torch.nn.init.normal_(self._first_attended_output)

        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        self._graph_pruning = GraphPruning(3,
                                           self._embedding_dim,
                                           encoder.get_output_dim(),
                                           dropout,
                                           timesteps=pruning_gnn_timesteps)

        if decoder_self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                past_attention=past_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)

        self._ent2ent_ff = FeedForward(action_embedding_dim, 1,
                                       action_embedding_dim,
                                       Activation.by_name('relu')())

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(
            evaluate,
            db_dir=os.path.join(dataset_path, 'database'),
            table=os.path.join(dataset_path, 'tables.json'),
            check_valid=False)

    @overrides
    def forward(
            self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            world: List[SpiderWorld],
            schema: Dict[str, torch.LongTensor],
            action_sequence: torch.LongTensor = None
    ) -> Dict[str, torch.Tensor]:

        max_len_entities = max(
            [len(w.db_context.knowledge_graph.entities) for w in world])
        batch_size = len(world)
        device = utterance['tokens'].device

        oracle_entities = []
        oracle_relevance_score = None
        if action_sequence is not None:
            # we want oracle supervision for which entities should be in the query, for the loss calculation
            for batch_index, batch_actions in enumerate(
                    action_sequence.squeeze(-1)):
                oracle_entities.append(
                    set([
                        valid_actions[batch_index][action][0].split(
                            ' -> ')[1].strip('["]') for action in batch_actions
                        if not valid_actions[batch_index][action][1]
                        and action >= 0
                    ]))
            oracle_relevance_score = [
                pad_sequence_to_length(w.get_oracle_relevance_score(oe),
                                       max_len_entities)
                for w, oe in zip(world, oracle_entities)
            ]
            oracle_relevance_score = torch.tensor(oracle_relevance_score,
                                                  dtype=torch.float,
                                                  device=device)

        initial_state = self._get_initial_state(utterance, world, schema,
                                                valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            action_mask = action_sequence != self._action_padding_index
        else:
            action_mask = None

        self.graph_mask = util.get_mask_from_sequence_lengths(
            torch.tensor([len(w.entities_names) for w in world],
                         device=device), max_len_entities).float()

        loss = torch.tensor([0]).float().to(device)

        if action_sequence is not None:
            graph_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                self.predicted_relevance_logits.squeeze(-1),
                oracle_relevance_score,
                reduction='none')
            graph_loss = (graph_loss *
                          self.graph_mask).sum() / self.graph_mask.sum()

            graph_loss *= self._graph_loss_lambda

            loss += graph_loss

        if self.training:
            try:
                decode_output = self._decoder_trainer.decode(
                    initial_state, self._transition_function,
                    (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))
                query_loss = decode_output['loss']
            except ZeroDivisionError:
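                # This can happen when the MML trainer finds no finished decoding
                # states for the batch; fall back to a zero loss in that case.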
                return {
                    'loss':
                    Parameter(torch.tensor([0]).float()).to(
                        action_sequence.device)
                }

            loss += ((1 - self._graph_loss_lambda) * query_loss)

            return {'loss': loss}
        else:
            if action_sequence is not None and action_sequence.size(1) > 1:
                try:
                    query_loss = self._decoder_trainer.decode(
                        initial_state, self._transition_function,
                        (action_sequence.unsqueeze(1),
                         action_mask.unsqueeze(1)))['loss']
                    loss += query_loss
                except ZeroDivisionError:
                    pass

            outputs: Dict[str, Any] = {'loss': loss}

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]

            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False)

            self._compute_validation_outputs(valid_actions, best_final_states,
                                             world, action_sequence, outputs)
            return outputs
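    # A compact restatement of the objective computed above, where lambda is
    # the graph_loss_lambda hyperparameter:
    #
    #     loss = lambda * graph_loss + (1 - lambda) * query_loss
    #
    # graph_loss is a masked binary cross-entropy over the predicted entity
    # relevance, and query_loss is the maximum-marginal-likelihood decoding loss.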

    def _get_initial_state(
            self, utterance: Dict[str, torch.LongTensor],
            worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor],
            actions: List[List[ProductionRule]]) -> GrammarBasedState:
        schema_text = schema['text']
        embedded_schema = self._question_embedder(schema_text,
                                                  num_wrapping_dims=1)
        schema_mask = util.get_text_field_mask(schema_text,
                                               num_wrapping_dims=1).float()

        embedded_utterance = self._question_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
        num_entities = max([
            len(world.db_context.knowledge_graph.entities) for world in worlds
        ])
        num_question_tokens = utterance['tokens'].size(1)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            worlds, num_entities, embedded_schema.device)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_schema.view(batch_size, num_entities * num_entity_tokens,
                                 self._embedding_dim),
            torch.transpose(embedded_utterance, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)
        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = schema['linking']

        linking_scores = question_entity_similarity_max_score

        feature_scores = self._linking_params(linking_features).squeeze(3)

        linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            worlds, linking_scores.transpose(1, 2), utterance_mask,
            entity_type_dict)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(worlds, num_entities,
                                                      linking_scores.device)

        if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
            # (batch_size, num_entities, embedding_dim)
            encoded_table = self._entity_encoder(embedded_schema, schema_mask)

            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(
                encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask(
                {'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float()

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(
                BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                                  neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(
                embedded_neighbors.float())

            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings +
                                           projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)

        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_utterance], 2)

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))

        # compute the relevance of each entity with the relevance GNN
        ent_relevance, ent_relevance_logits, ent_to_qst_lnk_probs = self._graph_pruning(
            worlds, encoder_outputs, entity_type_embeddings, linking_scores,
            utterance_mask, self._get_graph_adj_lists)
        # save this for loss calculation
        self.predicted_relevance_logits = ent_relevance_logits

        # multiply the embedding with the computed relevance
        graph_initial_embedding = entity_type_embeddings * ent_relevance

        encoder_output_dim = self._encoder.get_output_dim()
        if self._gnn:
            entities_graph_encoding = self._get_schema_graph_encoding(
                worlds, graph_initial_embedding)
            graph_link_embedding = util.weighted_sum(entities_graph_encoding,
                                                     linking_probabilities)
            encoder_outputs = torch.cat(
                (encoder_outputs, graph_link_embedding), dim=-1)
            encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
        else:
            entities_graph_encoding = None

        if self._self_attend:
            # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
            entities_ff = self._ent2ent_ff(entities_graph_encoding)
            linked_actions_linking_scores = torch.bmm(
                entities_ff, entities_ff.transpose(1, 2))
        else:
            linked_actions_linking_scores = [None] * batch_size

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(
                worlds[i], actions[i], linking_scores[i],
                linked_actions_linking_scores[i], entity_types[i],
                entities_graph_encoding[i]
                if entities_graph_encoding is not None else None)
            for i in range(batch_size)
        ]

        initial_sql_state = [
            SqlState(actions[i], self._parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=actions,
            action_entity_mapping=[
                w.get_action_entity_mapping() for w in worlds
            ])

        return initial_state

    def _create_grammar_state(
            self, world: SpiderWorld, possible_actions: List[ProductionRule],
            linking_scores: torch.Tensor,
            linked_actions_linking_scores: torch.Tensor,
            entity_types: torch.Tensor,
            entity_graph_encoding: torch.Tensor) -> GrammarStatelet:
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = world.valid_actions
        entity_map = {}
        entities = world.entities_names

        for entity_index, entity in enumerate(entities):
            entity_map[entity] = entity_index

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(
                    global_action_tensors,
                    dim=0).to(global_action_tensors[0].device).long()
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [
                    rule.split(' -> ')[1].strip('[]\"')
                    for rule in linked_rules
                ]

                entity_ids = [entity_map[entity] for entity in entities]

                entity_linking_scores = linking_scores[entity_ids]

                if linked_actions_linking_scores is not None:
                    entity_action_linking_scores = linked_actions_linking_scores[
                        entity_ids]

                if not self._decoder_use_graph_entities:
                    entity_type_tensor = entity_types[entity_ids]
                    entity_type_embeddings = (
                        self._entity_type_decoder_embedding(
                            entity_type_tensor).to(
                                entity_types.device).float())
                else:
                    entity_type_embeddings = entity_graph_encoding.index_select(
                        dim=0,
                        index=torch.tensor(
                            entity_ids, device=entity_graph_encoding.device))

                if self._self_attend:
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_embeddings,
                        list(linked_action_ids), entity_action_linking_scores)
                else:
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_embeddings,
                        list(linked_action_ids))

        return GrammarStatelet(['statement'], translated_valid_actions,
                               self.is_nonterminal)

    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence exactly matches a prefix of the
        # target sequence (targets here is a single 1-D action-index tensor).
        return torch.max(
            torch.min(targets_trimmed.eq(predicted_tensor), dim=0)[0]).item()
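    # For illustration (assumed toy values): with predicted = [4, 2] and
    # targets = tensor([4, 2, 9]), targets_trimmed is tensor([4, 2]), every
    # elementwise eq() entry is True, and the min/max reduction returns 1;
    # any mismatch within the prefix would return 0 instead.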

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            '_match/exact_match': self._exact_match.get_metric(reset),
            'sql_match': self._sql_evaluator_match.get_metric(reset),
            '_others/action_similarity':
            self._action_similarity.get_metric(reset),
            '_match/match_single': self._acc_single.get_metric(reset),
            '_match/match_hard': self._acc_multi.get_metric(reset),
            'beam_hit': self._beam_hit.get_metric(reset)
        }
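    # Note: AllenNLP's trainer conventionally hides metric names that start with
    # an underscore from the progress-bar display, which is presumably why the
    # per-difficulty and similarity metrics above are prefixed with '_' while
    # 'sql_match' and 'beam_hit' are always shown.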

    def _compute_validation_outputs(self,
                                    actions: List[List[ProductionRuleArray]],
                                    best_final_states: Mapping[
                                        int, Sequence[GrammarBasedState]],
                                    world: List[SpiderWorld],
                                    target_list: List[List[str]],
                                    outputs: Dict[str, Any]) -> None:
        batch_size = len(actions)

        outputs['predicted_sql_query'] = []
        outputs['candidates'] = []

        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]

        for i in range(batch_size):
            # gold sql exactly as given
            original_gold_sql_query = ' '.join(
                world[i].get_query_without_table_hints())

            if i not in best_final_states:
                self._exact_match(0)
                self._action_similarity(0)
                self._sql_evaluator_match(0)
                self._acc_multi(0)
                self._acc_single(0)
                outputs['predicted_sql_query'].append('')
                continue

            best_action_indices = best_final_states[i][0].action_history[0]

            action_strings = [
                action_mapping[(i, action_index)]
                for action_index in best_action_indices
            ]
            predicted_sql_query = action_sequence_to_sql(action_strings,
                                                         add_table_names=True)
            outputs['predicted_sql_query'].append(
                sqlparse.format(predicted_sql_query, reindent=False))

            targets = target_list[i].data if target_list is not None else None
            target_available = targets is not None and targets[0] > -1

            if target_available:
                sequence_in_targets = self._action_history_match(
                    best_action_indices, targets)
                self._exact_match(sequence_in_targets)

                sql_evaluator_match = self._evaluate_func(
                    original_gold_sql_query, predicted_sql_query,
                    world[i].db_id)
                self._sql_evaluator_match(sql_evaluator_match)

                similarity = difflib.SequenceMatcher(None, best_action_indices,
                                                     targets)
                self._action_similarity(similarity.ratio())

                difficulty = self._query_difficulty(targets, action_mapping, i)
                if difficulty:
                    self._acc_multi(sql_evaluator_match)
                else:
                    self._acc_single(sql_evaluator_match)

            beam_hit = False
            candidates = []
            for pos, final_state in enumerate(best_final_states[i]):
                action_indices = final_state.action_history[0]
                action_strings = [
                    action_mapping[(i, action_index)]
                    for action_index in action_indices
                ]
                candidate_sql_query = action_sequence_to_sql(
                    action_strings, add_table_names=True)

                correct = False
                if target_available:
                    correct = self._evaluate_func(original_gold_sql_query,
                                                  candidate_sql_query,
                                                  world[i].db_id)
                    if correct:
                        beam_hit = True
                    self._beam_hit(beam_hit)
                candidates.append({
                    'query':
                    action_sequence_to_sql(action_strings,
                                           add_table_names=True),
                    'correct':
                    correct
                })
            outputs['candidates'].append(candidates)
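    # Shape of the validation outputs assembled above (illustrative):
    #
    #     outputs['predicted_sql_query'] : List[str], the best SQL string per
    #                                      instance ('' when decoding failed)
    #     outputs['candidates']          : List[List[dict]], one
    #                                      {'query': str, 'correct': bool}
    #                                      entry per beam candidate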
Example #31
class NlvrDirectSemanticParser(NlvrSemanticParser):
    """
    ``NlvrDirectSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of lack
    of logical form annotations by maximizing the marginal likelihood of an approximate set of target
    sequences that yield the correct denotation. The main difference between this parser and
    ``NlvrCoverageSemanticParser`` is that while this parser takes the output of an offline search
    process as the set of target sequences for training, the latter performs search during training.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the TransitionFunction.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        Maximum number of steps for beam search after training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 dropout: float = 0.0) -> None:
        super(NlvrDirectSemanticParser,
              self).__init__(vocab=vocab,
                             sentence_embedder=sentence_embedder,
                             action_embedding_dim=action_embedding_dim,
                             encoder=encoder,
                             dropout=dropout)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            activation=Activation.by_name('tanh')(),
            add_action_bias=False,
            dropout=dropout)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        self._action_padding_index = -1

    @overrides
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrLanguage]],
            actions: List[List[ProductionRule]],
            identifier: List[str] = None,
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences, trained to maximize the
        marginal likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
            for i in range(batch_size)
        ]
        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if identifier is not None:
            outputs["identifier"] = identifier
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._decoder_beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            best_action_sequences: Dict[int, List[List[int]]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = [
                        best_final_states[i][0].action_history[0]
                    ]
                    best_action_sequences[i] = best_action_indices
            batch_action_strings = self._get_action_strings(
                actions, best_action_sequences)
            batch_denotations = self._get_denotations(batch_action_strings,
                                                      worlds)
            if target_action_sequences is not None:
                self._update_metrics(action_strings=batch_action_strings,
                                     worlds=worlds,
                                     label_strings=label_strings)
            else:
                if metadata is not None:
                    outputs["sentence_tokens"] = [
                        x["sentence_tokens"] for x in metadata
                    ]
                outputs['debug_info'] = []
                for i in range(batch_size):
                    outputs['debug_info'].append(
                        best_final_states[i][0].debug_info[0])  # type: ignore
                outputs["best_action_strings"] = batch_action_strings
                outputs["denotations"] = batch_denotations
                action_mapping = {}
                for batch_index, batch_actions in enumerate(actions):
                    for action_index, action in enumerate(batch_actions):
                        action_mapping[(batch_index, action_index)] = action[0]
                outputs['action_mapping'] = action_mapping
        return outputs

    def _update_metrics(self, action_strings: List[List[List[str]]],
                        worlds: List[List[NlvrLanguage]],
                        label_strings: List[List[str]]) -> None:
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): Using only the best decoded sequence. Define metrics for top-k sequences?
        batch_size = len(worlds)
        for i in range(batch_size):
            instance_action_strings = action_strings[i]
            sequence_is_correct = [False]
            if instance_action_strings:
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                # Taking only the best sequence.
                sequence_is_correct = self._check_denotation(
                    instance_action_strings[0], instance_label_strings,
                    instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'denotation_accuracy': self._denotation_accuracy.get_metric(reset),
            'consistency': self._consistency.get_metric(reset)
        }
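# A minimal sketch of the maximum marginal likelihood objective this parser
# optimizes, assuming per-sequence log-probabilities for the approximate target
# set (the values below are illustrative):
#
#     import torch
#     log_probs = torch.tensor([-2.3, -4.1, -3.0])   # log p(y_k | x) per target sequence
#     mml_loss = -torch.logsumexp(log_probs, dim=0)  # -log sum_k p(y_k | x)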
Example #32
class QuarelSemanticParser(Model):
    """
    A ``QuarelSemanticParser`` is a variant of ``WikiTablesSemanticParser`` with various
    tweaks and changes.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=0)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  If this is 0, no learned feature scores are added to the entity linking
        scores, and only the embedding similarity between the question and each entity is used.
    use_entities : ``bool``, optional (default=False)
        Whether dynamic entities are part of the action space
    num_entity_bits : ``int``, optional (default=0)
        Whether any bits are added to encoder input/output to represent tagged entities
    entity_bits_output : ``bool``, optional (default=False)
        Whether entity bits are added to the encoder output or input
    denotation_only : ``bool``, optional (default=False)
        Whether to only predict the target denotation, skipping the logical form decoder entirely.
    entity_similarity_mode : ``str``, optional (default="dot_product")
        How to compute vector similarity between question and entity tokens, can take values
        "dot_product" or "weighted_dot_product" (learned weights on each dimension)
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 num_linking_features: int = 0,
                 num_entity_bits: int = 0,
                 entity_bits_output: bool = True,
                 use_entities: bool = False,
                 denotation_only: bool = False,
                 # Deprecated parameter to load older models
                 entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
                 entity_similarity_mode: str = "dot_product",
                 rule_namespace: str = 'rule_labels') -> None:
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 1 # Hardcoded until we feed lf syntax into the model
        self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = \
                TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
            self._entity_similarity_layer._module.weight.data += 1  # pylint: disable=protected-access
        elif self._entity_similarity_mode == "dot_product":
            pass
        else:
            raise ValueError("Invalid entity_similarity_mode: {}".format(self._entity_similarity_mode))

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            self._denotation_classifier = torch.nn.Linear(self._encoder_output_dim,
                                                          self._num_denotation_cats)
            # Rest of init not needed for denotation only where no decoding to actions needed
            return

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        # We are tying the action embeddings used for input and output
        # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(self._encoder_output_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder_output_dim,
                                                       action_embedding_dim=action_embedding_dim,
                                                       input_attention=attention,
                                                       num_start_types=self._num_start_types,
                                                       predict_start_type_separately=False,
                                                       add_action_bias=self._add_action_bias,
                                                       mixture_feedforward=mixture_feedforward,
                                                       dropout=dropout)

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[QuarelWorld],
                actions: List[List[ProductionRule]],
                entity_bits: torch.Tensor = None,
                denotation_target: torch.Tensor = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``,
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value to
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                           num_entities * num_entity_tokens,
                                                                           self._embedding_dim),
                                                       torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(batch_size,
                                       num_question_tokens*num_entities*num_entity_tokens,
                                       self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(product)
                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                    question_mask, entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Fake linking_scores added for downstream code to not object
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(final_encoder_output)
            loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))

        initial_grammar_state = [self._create_grammar_state(world[i], actions[i],
                                                            linking_scores[i], entity_types[i])
                                 for i in range(batch_size)]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=None,
                                          debug_info=None)

        if self.training:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
            return outputs

        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            if self._linking_params is not None:
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
                    outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
                outputs['entities'].append(world[i].table_graph.entities)
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                        # Record success only once we know parsing did not raise.
                        self._has_logical_form(1.0)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index)
                        self._denotation_accuracy(denotation_accuracy)
                    score = math.exp(best_final_states[i][0].score[0].data.cpu().item())
                    outputs['answer_index'].append(predicted_answer_index)
                    outputs['score'].append(score)
                    outputs['parse_acc'].append(sequence_in_targets)
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['denotation_acc'].append(denotation_accuracy)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    outputs['parse_acc'].append(0)
                    outputs['logical_form'].append('')
                    outputs['denotation_acc'].append(0)
                    outputs['score'].append(0)
                    outputs['answer_index'].append(-1)
                    outputs['best_action_sequence'].append([])
                    outputs['debug_info'].append([])
                    self._has_logical_form(0.0)
            return outputs

    @staticmethod
    def _get_type_vector(worlds: List[QuarelWorld],
                         num_entities: int,
                         tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's
        type.  In addition, a map from flattened entity index to type id is returned, so that
        entity-type lookups elsewhere can share this one method.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)
        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types
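    # A sketch of the flattened-index convention above, with hypothetical entities
    # and num_entities=3: if batch instance 1 holds ['fb:cell.x', 'fb:row.y'], the
    # returned dict maps 1 * 3 + 0 -> 1 (cell) and 1 * 3 + 1 -> 3 (row), and the
    # padded row of the type tensor for that instance is [1, 3, 0].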

    def _get_linking_probabilities(self,
                                   worlds: List[QuarelWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(num_question_tokens,
                                                 num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()
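    # A worked sketch of the per-type normalization above, with hypothetical
    # scores: for one type with raw scores [s1, s2] for a given word, we softmax
    # over [0, s1, s2] (the prepended 0 is the null entity) and then drop the null
    # column, so a word's kept probabilities can sum to less than one; the
    # remainder is the probability of linking to no entity of that type.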

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()

    def _denotation_match(self, predicted_answer_index: int, target_answer_index: int) -> float:
        if predicted_answer_index < 0:
            # The logical form didn't resolve to an answer, so we credit it as a
            # random guess over the denotation categories.
            return 1.0/self._num_denotation_cats
        elif predicted_answer_index == target_answer_index:
            return 1.0
        return 0.0
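    # E.g., with a hypothetical self._num_denotation_cats of 2, an unresolved
    # prediction (index -1) gets 0.5, the expected accuracy of a random guess.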

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. parse_acc, which is the percentage of the time that our best output action sequence
            corresponds to a correct logical form

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation, including spurious correct answers using the wrong logical form

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        if self._denotation_only:
            metrics = {'denotation_acc': self._denotation_accuracy_cat.get_metric(reset)}
        else:
            metrics = {
                    'parse_acc': self._action_sequence_accuracy.get_metric(reset),
                    'denotation_acc': self._denotation_accuracy.get_metric(reset),
                    'lf_percent': self._has_logical_form.get_metric(reset),
            }
        return metrics

    def _create_grammar_state(self,
                              world: QuarelWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``QuarelWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [action_map[action_string] for action_string in action_strings]
            production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append((production_rule_array[2], action_index))
                else:
                    linked_actions.append((production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions, guarding
            # against non-terminals whose productions are all linked actions.
            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0)
                global_input_embeddings = self._action_embedder(global_action_tensor)
                if self._add_action_bias:
                    global_action_biases = self._action_biases(global_action_tensor)
                    global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases], dim=-1)
                global_output_embeddings = self._output_action_embedder(global_action_tensor)
                translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                           global_output_embeddings,
                                                           list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
                translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                           entity_type_embeddings,
                                                           list(linked_action_ids))

        return GrammarStatelet([START_SYMBOL],
                               translated_valid_actions,
                               type_declaration.is_nonterminal)
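    # Sketch of the global/linked split above, with hypothetical productions: a
    # global rule like 'e -> [<r,e>, r]' is embedded through its rule id, while a
    # linked rule like 'r -> fb:cell.2010' is represented by the linking scores
    # and type embedding of the entity on its right-hand side.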

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``FrictionQDecoderStep``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
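
# A self-contained, hedged sketch of the _action_history_match logic above, with
# hypothetical action indices (the helper name is ours, not the model's API; the
# prefix test is written with all/any for clarity):
import torch

def action_history_match(predicted, targets):
    # targets has shape (num_target_sequences, target_length).
    if len(predicted) > targets.size(1):
        return 0
    predicted_tensor = targets.new_tensor(predicted)
    targets_trimmed = targets[:, :len(predicted)]
    # 1 iff the prediction is a prefix of some target row.
    return int(targets_trimmed.eq(predicted_tensor).all(dim=1).any().item())

print(action_history_match([4, 7], torch.tensor([[4, 7, 0], [9, 9, 9]])))  # -> 1
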
class SpansText2SqlParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    input_attention: ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    span_extractor: ``SpanExtractor``, optional
        If provided, extracts spans representations based on the encoded inputs.
        The span representations are used for decoding.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 mydatabase: str,
                 schema_path: str,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 span_extractor: SpanExtractor = None) -> None:
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)
        # span extractor, allows using spans from the source as input to the decoder
        self._span_extractor = span_extractor
        self._exact_match = Average()
        self._action_similarity = Average()

        self._valid_sql_query = SqlValidity(mydatabase=mydatabase)
        self._token_match = TokenSequenceAccuracy()
        self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path)
        self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path)

        # the padding value used by IndexField
        self._action_padding_index = -1
        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout)
        self.parse_sql_on_decoding = True
        initializer(self)

    @overrides
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            action_sequence: torch.LongTensor = None,
            spans: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : torch.Tensor, optional (default=None)
            The correct action sequence, where each action is an index into the list of
            possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``; we
            remove the trailing dimension.
        spans: torch.Tensor, optional (default=None)
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of input spans that could be informative for the decoder. Comes from a ``ListField[SpanField]``
        """
        encode_outputs = self._encode(tokens, spans)
        # encode_outputs['mask'] shape: (batch_size, num_tokens)
        batch_size = encode_outputs['mask'].size(0)
        initial_state = self._get_initial_state(
            encode_outputs['encoder_outputs'], encode_outputs['mask'],
            valid_actions)
        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # action_sequence is of shape (batch_size, 1, target_sequence_length)
            # after we unsqueeze it for the MML trainer.
            try:
                loss_output = self._decoder_trainer.decode(
                    initial_state, self._transition_function,
                    (action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
            except ZeroDivisionError as e:
                logger.info(
                    f"Input utterance in ZeroDivisionError: {[t.text for t in tokens['tokens']]}"
                )
                raise e

            outputs.update(loss_output)

        if not self.training:
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=True)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['target_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Add the target sql from the target actions for sql tokens exact match comparison
                target_sql_query = ''
                if action_sequence is not None:
                    target_action_strings = [
                        action_mapping[i][action_index]
                        for action_index in action_sequence[i].data.tolist()
                        if action_index != self._action_padding_index
                    ]
                    target_sql_query = action_sequence_to_sql(
                        target_action_strings)
                    # target_sql_query = sqlparse.format(target_sql_query, reindent=True)
                target_sql_query_for_acc = target_sql_query.split()

                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._action_similarity(0)
                    outputs['target_sql_query'].append(
                        target_sql_query_for_acc)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [
                    action_mapping[i][action_index]
                    for action_index in best_action_indices
                ]

                predicted_sql_query = action_sequence_to_sql(action_strings)
                predicted_sql_query_for_acc = predicted_sql_query.split()
                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = self._action_history_match(
                        best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    # Compare plain ints: SequenceMatcher hashes elements, and 0-dim
                    # tensors hash by identity rather than by value.
                    similarity = difflib.SequenceMatcher(
                        None, best_action_indices, targets.tolist())
                    self._action_similarity(similarity.ratio())

                    # predicted_sql_query_for_acc = [token if '@' not in token else token.split('@')[1] for token in
                    #                                predicted_sql_query.split()]
                    # target_sql_query_for_acc = [token if '@' not in token else token.split('@')[1] for token in
                    #                             target_sql_query.split()]

                    predicted_sql_query_for_acc = re.sub(
                        r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                        r" \g<1> AS \g<1>\g<2> ", predicted_sql_query).split()
                    target_sql_query_for_acc = re.sub(
                        r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                        r" \g<1> AS \g<1>\g<2> ", target_sql_query).split()

                    self._valid_sql_query([predicted_sql_query_for_acc],
                                          [target_sql_query_for_acc])
                    self._token_match([predicted_sql_query_for_acc],
                                      [target_sql_query_for_acc])
                    self._kb_match([predicted_sql_query_for_acc],
                                   [target_sql_query_for_acc])
                    self._schema_free_match([predicted_sql_query_for_acc],
                                            [target_sql_query_for_acc])

                outputs['best_action_sequence'].append(action_strings)
                # outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['predicted_sql_query'].append(
                    predicted_sql_query_for_acc)
                outputs['target_sql_query'].append(target_sql_query_for_acc)
                outputs['debug_info'].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs

    def _encode(self,
                tokens: Dict[str, torch.LongTensor],
                spans: torch.Tensor = None):
        """
        If spans are provided, returns the encoded spans (by self._span_extractor) instead of the
        encoded utterance tokens
        """
        outputs = {}
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        outputs['mask'] = mask
        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance,
                                                      mask))
        outputs['encoder_outputs'] = encoder_outputs
        # if spans (over the input) are given, return their representation instead of the
        # source tokens representation
        if spans is not None and self._span_extractor is not None:
            # Looking at the span start index is enough to know if
            # this is padding or not. Shape: (batch_size, num_spans)
            span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
            span_representations = self._span_extractor(
                encoder_outputs, spans, mask, span_mask)
            outputs["mask"] = span_mask
            outputs["encoder_outputs"] = span_representations
        return outputs
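    # Shape sketch for the span branch above (hypothetical sizes): with
    # batch_size=2, num_tokens=10, num_spans=3 and an encoder output dim of 200,
    # 'encoder_outputs' goes from (2, 10, 200) to span representations of
    # (2, 3, span_dim), and 'mask' becomes the (2, 3) span mask.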

    def _get_initial_state(
            self, encoder_outputs: torch.Tensor, mask: torch.Tensor,
            actions: List[List[ProductionRule]]) -> GrammarBasedState:

        batch_size = encoder_outputs.size(0)
        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = encoder_outputs.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(actions[i]) for i in range(batch_size)
        ]
        initial_sql_state = [
            SqlStatelet(actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=actions,
            debug_info=None)
        return initial_state

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return predicted_tensor.equal(targets_trimmed)
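    # This variant compares against a single 1-D target sequence: e.g., with
    # hypothetical indices, predicted=[4, 7] against targets=tensor([4, 7, 0])
    # trims the target to tensor([4, 7]) and returns True (an exact prefix match).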

    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @staticmethod
    def get_terminals_mask(action_strings):
        terminals_mask = []
        for rule in action_strings:
            _, rhs = rule.split('->')
            rhs_values = rhs.strip().strip('[]').split(',')
            # A production is a terminal if its single right-hand-side value is quoted
            # (stripping '"' changes it), or if it expands to the TABLE_PLACEHOLDER constant.
            if len(rhs_values) == 1 and rhs_values[0].strip().strip('"') != rhs_values[0].strip():
                terminals_mask.append(1)
            elif 'TABLE_PLACEHOLDER' in rhs:
                terminals_mask.append(1)
            else:
                terminals_mask.append(0)
        return terminals_mask
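    # E.g., for hypothetical rules ['col_ref -> ["city.name"]', 'query -> [select, col_ref]']
    # the mask is [1, 0]: the quoted value is a terminal, the non-terminal expansion is not.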

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track these metrics here:

            1. exact_match, which is the percentage of the time that our best output action
            sequence matches the target SQL query exactly.

            2. sql_validity, which is the percentage of the time that decoding actually
            produces a valid SQL query.  We might not produce a valid SQL query if the
            decoder gets into a repetitive loop, or we're trying to produce a super long
            SQL query and run out of time steps, or something.

            3. action_similarity, which is how similar the predicted action sequence is to
            the target action sequence.  This is basically a soft version of exact_match.

        On top of these we report the token-match, knowledge-base-constant, and
        schema-free template accuracies from their respective metric objects.
        """

        validation_correct = self._exact_match._total_value  # pylint: disable=protected-access
        validation_total = self._exact_match._count  # pylint: disable=protected-access
        all_metrics = {
            '_exact_match_count': validation_correct,
            '_example_count': validation_total,
            'exact_match': self._exact_match.get_metric(reset),
            'sql_validity': self._valid_sql_query.get_metric(reset=reset)['sql_validity'],
            'action_similarity': self._action_similarity.get_metric(reset),
        }
        all_metrics.update(self._token_match.get_metric(reset=reset))
        all_metrics.update(self._kb_match.get_metric(reset=reset))
        all_metrics.update(self._schema_free_match.get_metric(reset=reset))
        return all_metrics

    def _create_grammar_state(
            self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  This parser only supports global actions, so there are no
        linking scores here.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        """
        device = util.get_device_of(self._action_embedder.weight)
        # TODO(Mark): This type is pure \(- . ^)/
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}

        actions_grouped_by_nonterminal: Dict[str, List[Tuple[
            ProductionRule, int]]] = defaultdict(list)
        for i, action in enumerate(possible_actions):
            if action.rule == "":
                continue
            if action.is_global_rule:
                actions_grouped_by_nonterminal[action.nonterminal].append(
                    (action, i))
            else:
                raise ValueError(
                    "The sql parser doesn't support non-global actions yet.")

        for key, production_rule_arrays in actions_grouped_by_nonterminal.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `production_rule_arrays`
            # are all of its valid productions.  This parser only has global actions, so
            # there is no global-vs-linked split to do here.
            global_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                global_actions.append(
                    (production_rule_array.rule_id, action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors,
                                                 dim=0).long()
                if device >= 0:
                    global_action_tensor = global_action_tensor.to(device)

                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)

                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))
        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal,
                               reverse_productions=True)
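    # Note: reverse_productions=True makes the GrammarStatelet push each production's
    # right-hand side onto the stack in reverse, so the leftmost non-terminal is
    # popped first and right-hand sides are expanded left to right.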

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(
                    predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions,
                                               probabilities):
                    if action != -1:
                        actions.append(
                            (action_mapping[batch_index][action], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['utterance_attention'] = action_debug_info.get(
                    'question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
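
# A self-contained, hedged sketch of the TABLE_PLACEHOLDER rewrite used above for
# the token-level accuracy comparison (the SQL fragment is hypothetical, and the
# helper name is ours, not the model's API):
import re

def resolve_table_placeholders(sql: str) -> str:
    # Rewrites ' TABLE_PLACEHOLDER AS NAME aliasN ' to ' NAME AS NAMEaliasN '.
    return re.sub(r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                  r" \g<1> AS \g<1>\g<2> ", sql)

print(resolve_table_placeholders(" SELECT * FROM TABLE_PLACEHOLDER AS CITY alias0 WHERE 1 = 1 "))
# -> " SELECT * FROM CITY AS CITYalias0 WHERE 1 = 1 "
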
Ejemplo n.º 34
0
class Text2SqlParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    input_attention: ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    """
    def __init__(self,
                 vocab: Vocabulary,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)

        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        # the padding value used by IndexField
        self._action_padding_index = -1
        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                            action_embedding_dim=action_embedding_dim,
                                                            input_attention=input_attention,
                                                            predict_start_type_separately=False,
                                                            add_action_bias=self._add_action_bias,
                                                            dropout=dropout)
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                valid_actions: List[List[ProductionRule]],
                action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : torch.Tensor, optional (default=None)
            The correct action sequence, where each action is an index into the list of
            possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``; we
            remove the trailing dimension.
        """
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        batch_size = embedded_utterance.size(0)

        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
        initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # action_sequence is of shape (batch_size, 1, target_sequence_length)
            # after we unsqueeze it for the MML trainer.
            loss_output = self._decoder_trainer.decode(initial_state,
                                                       self._transition_function,
                                                       (action_sequence.unsqueeze(1),
                                                        target_mask.unsqueeze(1)))
            outputs.update(loss_output)

        if not self.training:
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=True)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[i][action_index]
                                  for action_index in best_action_indices]

                predicted_sql_query = action_sequence_to_sql(action_strings)
                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    # Compare plain ints: SequenceMatcher hashes elements, and 0-dim
                    # tensors hash by identity rather than by value.
                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets.tolist())
                    self._action_similarity(similarity.ratio())

                outputs['best_action_sequence'].append(action_strings)
                outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs
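    # A minimal illustration (not part of the original class) of what `forward` returns at
    # inference time for a batch of one utterance (hypothetical values):
    #
    #     outputs['best_action_sequence'][0]  # e.g. ['statement -> [query, ";"]', ...]
    #     outputs['predicted_sql_query'][0]   # e.g. 'SELECT city.city_name\nFROM city'
    #     outputs['debug_info'][0]            # per-step attention and action probabilities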

    def _get_initial_state(self,
                           encoder_outputs: torch.Tensor,
                           mask: torch.Tensor,
                           actions: List[List[ProductionRule]]) -> GrammarBasedState:
        batch_size = encoder_outputs.size(0)
        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        initial_score = encoder_outputs.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [mask[i] for i in range(batch_size)]
        initial_rnn_state = []
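        # Note that every RnnStatelet receives the full encoder-output and mask lists for the
        # whole batch, not just row i; the statelet's batch index is used later to select rows.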
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_utterance,
                                                 encoder_output_list,
                                                 utterance_mask_list))

        initial_grammar_state = [self._create_grammar_state(actions[i]) for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          debug_info=None)
        return initial_state

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
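        # Illustrative example (hypothetical values):
        #   _action_history_match([3, 7, 9], torch.tensor([3, 7, 9, -1])) -> 1
        #   _action_history_match([3, 7], torch.tensor([3, 8])) -> 0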
        # Check whether the target is long enough to cover the prediction (including
        # start/end symbols).
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence exactly matches the target, trimmed to the
        # prediction's length.
        return int(predicted_tensor.equal(targets_trimmed))

    @staticmethod
    def is_nonterminal(token: str) -> bool:
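        # Terminals in this grammar are quoted strings, e.g. '"select"'; everything else,
        # e.g. 'statement', is treated as a non-terminal.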
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track four metrics here:

            1. exact_match, which is the percentage of the time that our best output action sequence
            matches the SQL query exactly.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that can be parsed (make sure
            you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation
            data but not for training data).

            3. valid_sql_query, which is the percentage of the time that decoding actually
            produces a valid SQL query.  We might fail to produce one if the decoder gets stuck
            in a repetitive loop, or if the query is so long that we run out of decoding steps.

            4. action_similarity, which is how similar the predicted action sequence is to the
            true action sequence, as measured by difflib.  This is essentially a soft version of
            exact_match.
        """

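        # ``Average`` only exposes the running mean through ``get_metric``, so we read its
        # protected fields to also report the raw correct / total counts.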
        validation_correct = self._exact_match._total_value # pylint: disable=protected-access
        validation_total = self._exact_match._count # pylint: disable=protected-access
        return {
                '_exact_match_count': validation_correct,
                '_example_count': validation_total,
                'exact_match': self._exact_match.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'valid_sql_query': self._valid_sql_query.get_metric(reset),
                'action_similarity': self._action_similarity.get_metric(reset)
                }

    def _create_grammar_state(self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules`` and use those to embed the valid actions for every
        non-terminal type.  This parser only supports global actions; linked (non-global)
        actions raise an error below.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        """
        device = util.get_device_of(self._action_embedder.weight)
        # TODO(Mark): This type is pure \(- . ^)/
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}

        actions_grouped_by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list)
        for i, action in enumerate(possible_actions):
            if action.rule == "":
                continue
            if action.is_global_rule:
                actions_grouped_by_nonterminal[action.nonterminal].append((action, i))
            else:
                raise ValueError("The sql parser doesn't support non-global actions yet.")

        for key, production_rule_arrays in actions_grouped_by_nonterminal.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `production_rule_arrays` are all
            # the valid productions of that non-terminal.  In this parser every production is a
            # global action; linked actions were rejected above.
            global_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                global_actions.append((production_rule_array.rule_id, action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0).long()
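                # ``util.get_device_of`` returns -1 for CPU tensors, so the action-id tensor is
                # only moved when the embedders live on a GPU.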
                if device >= 0:
                    global_action_tensor = global_action_tensor.to(device)

                global_input_embeddings = self._action_embedder(global_action_tensor)
                global_output_embeddings = self._output_action_embedder(global_action_tensor)

                translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                           global_output_embeddings,
                                                           list(global_action_ids))
        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal,
                               reverse_productions=True)
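    # Sketch of the structure built above (illustrative): translated_valid_actions ends up
    # looking like
    #     {'statement': {'global': (input_embeddings, output_embeddings, [0, 1, ...])}}
    # with one embedding row per production of the non-terminal.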

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method replaces action indices with their corresponding production strings, attaches
        the probabilities and attention weights recorded in ``debug_info`` to each predicted
        action, and adds the result as a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[batch_index][action], probability))
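                # Sorting (production string, probability) pairs orders the considered actions
                # alphabetically by production string.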
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['utterance_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
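    # Illustrative shape of one entry in output_dict['predicted_actions'][0] (hypothetical
    # values, assuming a single considered action at that step):
    #     {'predicted_action': 'statement -> [query, ";"]',
    #      'considered_actions': ('statement -> [query, ";"]',),
    #      'action_probabilities': (1.0,),
    #      'utterance_attention': [0.1, 0.7, 0.2]}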