Example #1
    def __init__(self,
                 vocab: Vocabulary,
                 mydatabase: str,
                 schema_path: str,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)

        self._exact_match = Average()
        self._action_similarity = Average()

        self._valid_sql_query = SqlValidity(mydatabase=mydatabase)
        self._token_match = TokenSequenceAccuracy()
        self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path)
        self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path)
        self._coverage_loss = CoverageAttentionLossMetric()

        # the padding value used by IndexField
        self._action_padding_index = -1
        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout)
        initializer(self)
class SpansText2SqlParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    input_attention: ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    span_extractor: ``SpanExtractor``, optional
        If provided, extracts spans representations based on the encoded inputs.
        The span representations are used for decoding.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 mydatabase: str,
                 schema_path: str,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 span_extractor: SpanExtractor = None) -> None:
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)
        # span extractor, allows using spans from the source as input to the decoder
        self._span_extractor = span_extractor
        self._exact_match = Average()
        self._action_similarity = Average()

        self._valid_sql_query = SqlValidity(mydatabase=mydatabase)
        self._token_match = TokenSequenceAccuracy()
        self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path)
        self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path)

        # the padding value used by IndexField
        self._action_padding_index = -1
        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout)
        self.parse_sql_on_decoding = True
        initializer(self)

    @overrides
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            action_sequence: torch.LongTensor = None,
            spans: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : torch.Tensor, optional (default=None)
            The gold action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.
        spans: torch.Tensor, optional (default=None)
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of input spans that could be informative for the decoder. Comes from a ``ListField[SpanField]``
        """
        encode_outputs = self._encode(tokens, spans)
        # encode_outputs['mask'] shape: (batch_size, num_tokens)
        batch_size = encode_outputs['mask'].size(0)
        initial_state = self._get_initial_state(
            encode_outputs['encoder_outputs'], encode_outputs['mask'],
            valid_actions)
        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # action_sequence has shape (batch_size, sequence_length); we unsqueeze it to
            # (batch_size, 1, sequence_length) for the MML trainer.
            try:
                loss_output = self._decoder_trainer.decode(
                    initial_state, self._transition_function,
                    (action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
            except ZeroDivisionError as e:
                logger.info(
                    f"Input utterance in ZeroDivisionError: {[t.text for t in tokens['tokens']]}"
                )
                raise e

            outputs.update(loss_output)

        if not self.training:
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=True)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['target_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Add the target sql from the target actions for sql tokens exact match comparison
                target_sql_query = ''
                if action_sequence is not None:
                    target_action_strings = [
                        action_mapping[i][action_index]
                        for action_index in action_sequence[i].data.tolist()
                        if action_index != self._action_padding_index
                    ]
                    target_sql_query = action_sequence_to_sql(
                        target_action_strings)
                    # target_sql_query = sqlparse.format(target_sql_query, reindent=True)
                target_sql_query_for_acc = target_sql_query.split()

                # Decoding may not have terminated with any completed valid SQL queries, if
                # `self._max_decoding_steps` isn't large enough (or if the model is not trained
                # enough and gets into an infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._action_similarity(0)
                    outputs['target_sql_query'].append(
                        target_sql_query_for_acc)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [
                    action_mapping[i][action_index]
                    for action_index in best_action_indices
                ]

                predicted_sql_query = action_sequence_to_sql(action_strings)
                predicted_sql_query_for_acc = predicted_sql_query.split()
                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = self._action_history_match(
                        best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(
                        None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())
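                    # For intuition (hypothetical index sequences): predicted [12, 4, 4, 7]
                    # vs. gold [12, 4, 7] share 3 elements in matching blocks, giving
                    # ratio = 2 * 3 / (4 + 3) ≈ 0.86.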

                    # predicted_sql_query_for_acc = [token if '@' not in token else token.split('@')[1] for token in
                    #                                predicted_sql_query.split()]
                    # target_sql_query_for_acc = [token if '@' not in token else token.split('@')[1] for token in
                    #                             target_sql_query.split()]

                    predicted_sql_query_for_acc = re.sub(
                        r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                        r" \g<1> AS \g<1>\g<2> ", predicted_sql_query).split()
                    target_sql_query_for_acc = re.sub(
                        r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                        r" \g<1> AS \g<1>\g<2> ", target_sql_query).split()

                    self._valid_sql_query([predicted_sql_query_for_acc],
                                          [target_sql_query_for_acc])
                    self._token_match([predicted_sql_query_for_acc],
                                      [target_sql_query_for_acc])
                    self._kb_match([predicted_sql_query_for_acc],
                                   [target_sql_query_for_acc])
                    self._schema_free_match([predicted_sql_query_for_acc],
                                            [target_sql_query_for_acc])

                outputs['best_action_sequence'].append(action_strings)
                # outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['predicted_sql_query'].append(
                    predicted_sql_query_for_acc)
                outputs['target_sql_query'].append(target_sql_query_for_acc)
                outputs['debug_info'].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs

    def _encode(self,
                tokens: Dict[str, torch.LongTensor],
                spans: torch.Tensor = None):
        """
        If spans are provided, returns the encoded spans (by self._span_extractor) instead of the
        encoded utterance tokens
        """
        outputs = {}
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        outputs['mask'] = mask
        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance,
                                                      mask))
        outputs['encoder_outputs'] = encoder_outputs
        # if spans (over the input) are given, return their representation instead of the
        # source tokens representation
        if spans is not None and self._span_extractor is not None:
            # Looking at the span start index is enough to know if
            # this is padding or not. Shape: (batch_size, num_spans)
            span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
            span_representations = self._span_extractor(
                encoder_outputs, spans, mask, span_mask)
            outputs["mask"] = span_mask
            outputs["encoder_outputs"] = span_representations
        return outputs

    def _get_initial_state(
            self, encoder_outputs: torch.Tensor, mask: torch.Tensor,
            actions: List[List[ProductionRule]]) -> GrammarBasedState:

        batch_size = encoder_outputs.size(0)
        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = encoder_outputs.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(actions[i]) for i in range(batch_size)
        ]
        initial_sql_state = [
            SqlStatelet(actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=actions,
            debug_info=None)
        return initial_state

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
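        # Example (hypothetical indices): predicted = [3, 5, 7] against
        # targets = tensor([3, 5, 7, -1]) compares against the trimmed prefix [3, 5, 7]
        # and returns True (1); predicted = [3, 6] returns False, since [3, 6] != [3, 5].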
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence matches the corresponding prefix of the targets.
        return predicted_tensor.equal(targets_trimmed)

    @staticmethod
    def is_nonterminal(token: str):
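        # Quoted tokens are terminals; everything else is treated as a non-terminal, e.g.
        # is_nonterminal('"select"') is False while is_nonterminal('where_clause') is True.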
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @staticmethod
    def get_terminals_mask(action_strings):
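        # A rule counts as a terminal production if its right-hand side is a single quoted
        # value or mentions TABLE_PLACEHOLDER, e.g. (hypothetical rules):
        #   'table_name -> ["CITY"]'                      -> 1
        #   'table_ref -> [TABLE_PLACEHOLDER, AS, name]'  -> 1
        #   'query -> [select_core, orderby_clause]'      -> 0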
        terminals_mask = []
        for rule in action_strings:
            _, rhs = rule.split('->')
            rhs_values = rhs.strip().strip('[]').split(',')
            if len(rhs_values) == 1 and rhs_values[0].strip().strip(
                    '"') != rhs_values[0].strip():
                terminals_mask.append(1)
            elif 'TABLE_PLACEHOLDER' in rhs:
                terminals_mask.append(1)
            else:
                terminals_mask.append(0)
        return terminals_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track four metrics here:

            1. exact_match, which is the percentage of the time that our best output action sequence
            matches the SQL query exactly.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that can be parsed. (make sure
            you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. valid_sql_query, which is the percentage of time that decoding actually produces a
            valid SQL query.  We might not produce a valid SQL query if the decoder gets
            into a repetitive loop, or we're trying to produce a super long SQL query and run
            out of time steps, or something.

            4. action_similarity, which is how similar the action sequence predicted is to the actual
            action sequence. This is basically a soft measure of exact_match.
        """

        validation_correct = self._exact_match._total_value  # pylint: disable=protected-access
        validation_total = self._exact_match._count  # pylint: disable=protected-access
        all_metrics = {
            '_exact_match_count': validation_correct,
            '_example_count': validation_total,
            'exact_match': self._exact_match.get_metric(reset),
            'sql_validity': self._valid_sql_query.get_metric(reset=reset)['sql_validity'],
            'action_similarity': self._action_similarity.get_metric(reset),
        }
        all_metrics.update(self._token_match.get_metric(reset=reset))
        all_metrics.update(self._kb_match.get_metric(reset=reset))
        all_metrics.update(self._schema_free_match.get_metric(reset=reset))
        return all_metrics

    def _create_grammar_state(
            self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  Non-global (linked) actions are not supported by this parser.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        """
        device = util.get_device_of(self._action_embedder.weight)
        # TODO(Mark): This type is pure \(- . ^)/
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}

        actions_grouped_by_nonterminal: Dict[str, List[Tuple[
            ProductionRule, int]]] = defaultdict(list)
        for i, action in enumerate(possible_actions):
            if action.rule == "":
                continue
            if action.is_global_rule:
                actions_grouped_by_nonterminal[action.nonterminal].append(
                    (action, i))
            else:
                raise ValueError(
                    "The sql parser doesn't support non-global actions yet.")

        for key, production_rule_arrays in actions_grouped_by_nonterminal.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `production_rule_arrays` are all
            # the valid productions of that non-terminal.  All of the actions here are global
            # actions, so we just gather their rule ids and embed them.
            global_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                global_actions.append(
                    (production_rule_array.rule_id, action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors,
                                                 dim=0).long()
                if device >= 0:
                    global_action_tensor = global_action_tensor.to(device)

                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)

                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))
        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal,
                               reverse_productions=True)

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(
                    predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions,
                                               probabilities):
                    if action != -1:
                        actions.append(
                            (action_mapping[batch_index][action], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['utterance_attention'] = action_debug_info.get(
                    'question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
class DropSeq2Seq(Model):
    """
    Adaptation of the ``SimpleSeq2Seq`` class in allennlp_models, extended to support input spans in addition to input tokens.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    schema_path : ``str``, optional (default = None)
        If provided, SQL-specific metrics (knowledge-base constants accuracy and schema-free
        template accuracy) are computed against this schema during evaluation.
    span_extractor : ``SpanExtractor``, optional (default = None)
        If provided, extracts span representations from the encoded inputs; the decoder then
        attends over these span representations instead of the source token representations.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 schema_path: str = None,
                 attention: Attention = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.0,
                 dec_dropout: float = 0.0,
                 token_based_metric: Metric = None,
                 span_extractor: SpanExtractor = None,
                 sql_metrics: bool = True) -> None:
        super(DropSeq2Seq, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None

        if token_based_metric:
            self._token_based_metric = token_based_metric
        else:
            self._token_based_metric = TokenSequenceAccuracy()
        self._sql_metrics = schema_path is not None
        if self._sql_metrics:
            self._schema_free_match = GlobalTemplAccuracy(
                schema_path=schema_path)
            self._kb_match = KnowledgeBaseConstsAccuracy(
                schema_path=schema_path)

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder
        self._emb_dropout = Dropout(p=emb_dropout)
        self._dec_dropout = Dropout(p=dec_dropout)

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            self._attention = attention
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim(
        )
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

        # span extractor, allows using spans from the source as input to the decoder
        self._span_extractor = span_extractor

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward_on_instances(
            self, instances: List[Instance]) -> List[Dict[str, numpy.ndarray]]:
        """
        Takes a list of  :class:`~allennlp.data.instance.Instance`s, converts that text into
        arrays using this model's :class:`Vocabulary`, passes those arrays through
        :func:`self.forward()` and :func:`self.decode()` (which by default does nothing)
        and returns the result.  Before returning the result, we convert any
        ``torch.Tensors`` into numpy arrays and separate the
        batched output into a list of individual dicts per instance. Note that typically
        this will be faster on a GPU (and conditionally, on a CPU) than repeated calls to
        :func:`forward_on_instance`.

        Parameters
        ----------
        instances : List[Instance], required
            The instances to run the model on.

        Returns
        -------
        A list of the models output for each instance.
        """
        batch_size = len(instances)
        with torch.no_grad():
            cuda_device = self._get_prediction_device()
            dataset = Batch(instances)
            dataset.index_instances(self.vocab)
            model_input = util.move_to_device(dataset.as_tensor_dict(),
                                              cuda_device)
            outputs = self.decode(self(**model_input))

            instance_separated_output: List[Dict[str, numpy.ndarray]] = [
                {} for _ in dataset.instances
            ]
            for name, output in list(outputs.items()):
                if isinstance(output, torch.Tensor):
                    # NOTE(markn): This is a hack because 0-dim pytorch tensors are not iterable.
                    # This occurs with batch size 1, because we still want to include the loss in that case.
                    if output.dim() == 0:
                        output = output.unsqueeze(0)

                    if output.size(0) != batch_size:
                        self._maybe_warn_for_unseparable_batches(name)
                        continue
                    output = output.detach().cpu().numpy()
                elif len(output) != batch_size:
                    self._maybe_warn_for_unseparable_batches(name)
                    continue
                for instance_output, batch_element in zip(
                        instance_separated_output, output):
                    instance_output[name] = batch_element

            for instance_output, instance_input in zip(
                    instance_separated_output, instances):
                for field in instance_input.fields:
                    if field == 'spans' and 'source_tokens' in instance_input.fields:
                        spans = []
                        source_tokens = instance_input.fields['source_tokens'].tokens
                        for indexfield in instance_input.fields[field].field_list:
                            spans.append(
                                source_tokens[indexfield.span_start:indexfield.span_end + 1])
                        # Store the extracted span token lists in the per-instance output.
                        instance_output[field] = spans
                    else:
                        instance_output[field] = instance_input.fields[field].tokens

            return instance_separated_output

    @overrides
    def forward(
        self,  # type: ignore
        source_tokens: Dict[str, torch.LongTensor],
        spans: torch.IntTensor = None,
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        spans : ``torch.IntTensor``
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of spans that could be informative. Comes from a ``ListField[SpanField]`` of
            indices into the text of the input.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens, spans)

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens:
                if self._bleu:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    self._bleu(best_predictions, target_tokens["tokens"])

                predicted_tokens = self.decode(output_dict)["predicted_tokens"]
                target_tokens_str = self.decode_target_tokens(target_tokens)

                if self._token_based_metric:
                    self._token_based_metric(predicted_tokens,
                                             target_tokens_str)
                if self._sql_metrics:
                    self._kb_match(predicted_tokens, target_tokens_str)
                    self._schema_free_match(predicted_tokens,
                                            target_tokens_str)

        return output_dict

    def decode_target_tokens(self, target_tokens):
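        # Example (hypothetical indices, with start=1, end=2, pad=0):
        # [1, 14, 27, 2, 0] is first cut at the end symbol -> [1, 14, 27], then everything up to
        # and including the start symbol is dropped -> [14, 27], which is mapped back to tokens.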
        target_indices = target_tokens['tokens'].detach().cpu().numpy()
        target_tokens_output = []
        for i in range(target_indices.shape[0]):
            cur_target_indices = target_indices[i]
            cur_target_indices = list(cur_target_indices)
            if self._end_index in cur_target_indices:
                cur_target_indices = cur_target_indices[:cur_target_indices.
                                                        index(self._end_index)]
            if self._start_index in cur_target_indices:
                cur_target_indices = cur_target_indices[
                    cur_target_indices.index(self._start_index) + 1:]
            target_tokens_str = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace)
                for x in cur_target_indices
            ]
            target_tokens_output.append(target_tokens_str)

        return target_tokens_output

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(self,
                source_tokens: Dict[str, torch.Tensor],
                spans: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        outputs = {}
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        outputs["source_mask"] = source_mask
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        encoder_outputs = self._emb_dropout(encoder_outputs)
        outputs["encoder_outputs"] = encoder_outputs
        # if spans (over the input) are given, return their representation instead of the
        # source tokens representation
        if spans is not None and self._span_extractor is not None:
            # Looking at the span start index is enough to know if
            # this is padding or not. Shape: (batch_size, num_spans)
            span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
            if span_mask.dim() == 1:
                span_mask = span_mask.unsqueeze(1)
            span_representations = self._span_extractor(
                encoder_outputs, spans, source_mask, span_mask)
            outputs["source_mask"] = span_mask
            outputs["encoder_outputs"] = span_representations
        return outputs

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
            self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(
        self,
        state: Dict[str, torch.Tensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        step_attention_input_weights: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # With probability ``_scheduled_sampling_ratio`` during training, feed the model's
                # own prediction from the previous time step instead of the gold token.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(
                input_choices, state)

            # list of tensors, shape: (batch_size, 1, max_input_sequence_length)
            step_attention_input_weights.append(
                state['input_weights'].unsqueeze(1))

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        # shape: (batch_size, num_decoding_steps, max_input_sequence_length)
        attention_input_weights = torch.cat(step_attention_input_weights, 1)

        output_dict = {
            "predictions": predictions,
            "attention_input_weights": attention_input_weights
        }

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict
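    # A minimal, hypothetical sketch of consuming the beam-search output above;
    # beams are expected to be sorted by score, so index 0 is the best one:
    #
    #     output_dict = self._forward_beam_search(state)
    #     # shape: (batch_size, num_decoding_steps)
    #     best_predictions = output_dict["predictions"][:, 0, :]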

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_input, input_weights = self._prepare_attended_input(
                decoder_hidden, encoder_outputs, source_mask)
            state["input_weights"] = input_weights

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        decoder_input = self._dec_dropout(decoder_input)

        # shape (decoder_hidden): (group_size, decoder_output_dim)
        # shape (decoder_context): (group_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(
            self._dec_dropout(decoder_hidden))

        return output_projections, state
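    # ``self._decoder_cell`` above is presumably an ``LSTMCell`` constructed
    # elsewhere with input size ``encoder_output_dim + target_embedding_dim``
    # when attention is used (``target_embedding_dim`` alone otherwise).  Since
    # the decoder hidden state is initialised from the final encoder state,
    # ``decoder_output_dim`` is expected to equal ``encoder_output_dim`` here.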

    def _prepare_attended_input(
            self,
            decoder_hidden_state: torch.Tensor = None,
            encoder_outputs: torch.Tensor = None,
            encoder_outputs_mask: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure the mask is a FloatTensor; otherwise the multiplication inside
        # the attention module will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(decoder_hidden_state, encoder_outputs,
                                        encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input, input_weights
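    # Shape sketch for the attention step above (sizes are hypothetical):
    #     decoder_hidden_state: (batch_size, decoder_output_dim)           e.g. (2, 8)
    #     encoder_outputs:      (batch_size, src_len, encoder_output_dim)  e.g. (2, 5, 8)
    #     input_weights:        (batch_size, src_len)                      e.g. (2, 5),
    #                           typically normalised with a masked softmax
    #     attended_input:       (batch_size, encoder_output_dim)           e.g. (2, 8)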

    @staticmethod
    def _get_loss(logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1), and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._bleu:
                all_metrics.update(self._bleu.get_metric(reset=reset))
            all_metrics.update(
                self._token_based_metric.get_metric(reset=reset))
            if self._sql_metrics:
                all_metrics.update(self._kb_match.get_metric(reset=reset))
                all_metrics.update(
                    self._schema_free_match.get_metric(reset=reset))
        return all_metrics