def __init__(self,
             vocab: Vocabulary,
             mydatabase: str,
             schema_path: str,
             utterance_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             decoder_beam_search: BeamSearch,
             max_decoding_steps: int,
             input_attention: Attention,
             add_action_bias: bool = True,
             dropout: float = 0.0,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None) -> None:
    super().__init__(vocab, regularizer)
    self._utterance_embedder = utterance_embedder
    self._encoder = encoder
    self._max_decoding_steps = max_decoding_steps
    self._add_action_bias = add_action_bias
    self._dropout = torch.nn.Dropout(p=dropout)

    self._exact_match = Average()
    self._action_similarity = Average()
    self._valid_sql_query = SqlValidity(mydatabase=mydatabase)
    self._token_match = TokenSequenceAccuracy()
    self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path)
    self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path)
    self._coverage_loss = CoverageAttentionLossMetric()

    # the padding value used by IndexField
    self._action_padding_index = -1

    num_actions = vocab.get_vocab_size("rule_labels")
    input_action_dim = action_embedding_dim
    if self._add_action_bias:
        input_action_dim += 1
    self._action_embedder = Embedding(num_embeddings=num_actions,
                                      embedding_dim=input_action_dim)
    self._output_action_embedder = Embedding(num_embeddings=num_actions,
                                             embedding_dim=action_embedding_dim)

    # This is what we pass as input in the first step of decoding, when we don't have a
    # previous action, or a previous utterance attention.
    self._first_action_embedding = torch.nn.Parameter(
        torch.FloatTensor(action_embedding_dim))
    self._first_attended_utterance = torch.nn.Parameter(
        torch.FloatTensor(encoder.get_output_dim()))
    torch.nn.init.normal_(self._first_action_embedding)
    torch.nn.init.normal_(self._first_attended_utterance)

    self._beam_search = decoder_beam_search
    self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
    self._transition_function = BasicTransitionFunction(
        encoder_output_dim=self._encoder.get_output_dim(),
        action_embedding_dim=action_embedding_dim,
        input_attention=input_attention,
        add_action_bias=self._add_action_bias,
        dropout=dropout)
    initializer(self)

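# --- Illustrative sketch, not part of the original model code.  When ``add_action_bias``
# is True above, the input action embedder gets one extra dimension.  The assumption here
# (not verified against the transition function's internals) is that the decoder appends a
# constant 1.0 to its query vector, so the extra dimension behaves as a learned per-action
# bias when scoring actions, as the toy check below shows. ---
def _action_bias_sketch() -> None:
    action_embedding_dim = 4
    query = torch.randn(action_embedding_dim)
    action_embedding = torch.randn(action_embedding_dim)
    action_bias = torch.randn(1)

    extended_query = torch.cat([query, torch.ones(1)])
    extended_embedding = torch.cat([action_embedding, action_bias])

    # scoring with the extra dimension == plain score + per-action bias
    assert torch.allclose(extended_query @ extended_embedding,
                          query @ action_embedding + action_bias)
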
class SpansText2SqlParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should
        take?  This only applies at evaluation time, not during training.
    input_attention : ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders
        (pytorch LSTMs do not apply dropout to their last layer).
    span_extractor : ``SpanExtractor``, optional
        If provided, extracts span representations based on the encoded inputs.  The span
        representations are used for decoding.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 mydatabase: str,
                 schema_path: str,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 span_extractor: SpanExtractor = None) -> None:
        super().__init__(vocab, regularizer)
        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)

        # span extractor, allows using spans from the source as input to the decoder
        self._span_extractor = span_extractor

        self._exact_match = Average()
        self._action_similarity = Average()
        self._valid_sql_query = SqlValidity(mydatabase=mydatabase)
        self._token_match = TokenSequenceAccuracy()
        self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path)
        self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path)

        # the padding value used by IndexField
        self._action_padding_index = -1

        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions,
                                                 embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout)
        self.parse_sql_on_decoding = True
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                valid_actions: List[List[ProductionRule]],
                action_sequence: torch.LongTensor = None,
                spans: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a
        DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``.  This
            will be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these and
            use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : ``torch.Tensor``, optional (default=None)
            The gold action sequence, where each action is an index into the list of possible
            actions.  This tensor has shape ``(batch_size, sequence_length, 1)``.  We remove
            the trailing dimension.
        spans : ``torch.Tensor``, optional (default=None)
            A tensor of shape ``(batch_size, num_spans, 2)``, representing the inclusive start
            and end indices of input spans that could be informative for the decoder.  Comes
            from a ``ListField[SpanField]``.
        """
        encode_outputs = self._encode(tokens, spans)
        # encode_outputs['mask'] shape: (batch_size, num_tokens)
        batch_size = encode_outputs['mask'].size(0)
        initial_state = self._get_initial_state(encode_outputs['encoder_outputs'],
                                                encode_outputs['mask'],
                                                valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # ``action_sequence`` has shape (batch_size, 1, target_sequence_length) after we
            # unsqueeze it for the MML trainer below.
            try:
                loss_output = self._decoder_trainer.decode(
                    initial_state,
                    self._transition_function,
                    (action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
            except ZeroDivisionError as e:
                logger.info(
                    f"Input utterance in ZeroDivisionError: {[t.text for t in tokens['tokens']]}")
                raise e
            outputs.update(loss_output)

        if not self.training:
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along
            # in our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=True)

            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['target_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Build the target SQL from the target actions, for SQL-token exact-match
                # comparison.
                target_sql_query = ''
                if action_sequence is not None:
                    target_action_strings = [action_mapping[i][action_index]
                                             for action_index in action_sequence[i].data.tolist()
                                             if action_index != self._action_padding_index]
                    target_sql_query = action_sequence_to_sql(target_action_strings)
                    # target_sql_query = sqlparse.format(target_sql_query, reindent=True)
                target_sql_query_for_acc = target_sql_query.split()

                # Decoding may not have terminated with any completed valid SQL queries, if
                # `num_steps` isn't long enough (or if the model is not trained enough and gets
                # into an infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._action_similarity(0)
                    outputs['target_sql_query'].append(target_sql_query_for_acc)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]
                action_strings = [action_mapping[i][action_index]
                                  for action_index in best_action_indices]
                predicted_sql_query = action_sequence_to_sql(action_strings)
                predicted_sql_query_for_acc = predicted_sql_query.split()

                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices,
                                                                     targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                    # predicted_sql_query_for_acc = [token if '@' not in token else token.split('@')[1]
                    #                                for token in predicted_sql_query.split()]
                    # target_sql_query_for_acc = [token if '@' not in token else token.split('@')[1]
                    #                             for token in target_sql_query.split()]
                    predicted_sql_query_for_acc = re.sub(
                        r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                        r" \g<1> AS \g<1>\g<2> ",
                        predicted_sql_query).split()
                    target_sql_query_for_acc = re.sub(
                        r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                        r" \g<1> AS \g<1>\g<2> ",
                        target_sql_query).split()

                    self._valid_sql_query([predicted_sql_query_for_acc],
                                          [target_sql_query_for_acc])
                    self._token_match([predicted_sql_query_for_acc],
                                      [target_sql_query_for_acc])
                    self._kb_match([predicted_sql_query_for_acc],
                                   [target_sql_query_for_acc])
                    self._schema_free_match([predicted_sql_query_for_acc],
                                            [target_sql_query_for_acc])

                outputs['best_action_sequence'].append(action_strings)
                # outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['predicted_sql_query'].append(predicted_sql_query_for_acc)
                outputs['target_sql_query'].append(target_sql_query_for_acc)
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs

    def _encode(self, tokens: Dict[str, torch.LongTensor], spans: torch.Tensor = None):
        """
        If spans are provided, returns the encoded spans (by ``self._span_extractor``) instead
        of the encoded utterance tokens.
        """
        outputs = {}
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        outputs['mask'] = mask
        # shape: (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
        outputs['encoder_outputs'] = encoder_outputs

        # If spans (over the input) are given, return their representation instead of the
        # source tokens representation.
        if spans is not None and self._span_extractor is not None:
            # Looking at the span start index is enough to know if
            # this is padding or not.  Shape: (batch_size, num_spans)
            span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
            span_representations = self._span_extractor(encoder_outputs, spans, mask, span_mask)
            outputs["mask"] = span_mask
            outputs["encoder_outputs"] = span_representations
        return outputs

    def _get_initial_state(self,
                           encoder_outputs: torch.Tensor,
                           mask: torch.Tensor,
                           actions: List[List[ProductionRule]]) -> GrammarBasedState:
        batch_size = encoder_outputs.size(0)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        initial_score = encoder_outputs.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension
        # in all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a
        # list of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.
        # Then we won't have to do any index selects, or anything, we'll just do some
        # `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_utterance,
                                                 encoder_output_list,
                                                 utterance_mask_list))

        initial_grammar_state = [self._create_grammar_state(actions[i])
                                 for i in range(batch_size)]
        initial_sql_state = [SqlStatelet(actions[i], self.parse_sql_on_decoding)
                             for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          sql_state=initial_sql_state,
                                          possible_actions=actions,
                                          debug_info=None)
        return initial_state

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or
        # something.
        # Check if target is big enough to cover prediction (including start/end symbols).
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return predicted_tensor.equal(targets_trimmed)

    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @staticmethod
    def get_terminals_mask(action_strings):
        terminals_mask = []
        for j, rule in enumerate(action_strings):
            lhs, rhs = rule.split('->')
            rhs_values = rhs.strip().strip('[]').split(',')
            if len(rhs_values) == 1 and rhs_values[0].strip().strip('"') != rhs_values[0].strip():
                terminals_mask.append(1)
            elif 'TABLE_PLACEHOLDER' in rhs:
                terminals_mask.append(1)
            else:
                terminals_mask.append(0)
        return terminals_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track four metrics here:

            1. exact_match, which is the percentage of the time that our best output action
            sequence matches the SQL query exactly.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should
            usually report in an experimental result.  You need to be careful, though, that
            you're computing this on the full data, and not just the subset that can be parsed.
            (make sure you pass "keep_if_unparseable=True" to the dataset reader, which we do
            for validation data, but not training data).

            3. valid_sql_query, which is the percentage of time that decoding actually produces
            a valid SQL query.  We might not produce a valid SQL query if the decoder gets into
            a repetitive loop, or we're trying to produce a super long SQL query and run out of
            time steps, or something.

            4. action_similarity, which is how similar the action sequence predicted is to the
            actual action sequence.  This is basically a soft measure of exact_match.
        """
        validation_correct = self._exact_match._total_value  # pylint: disable=protected-access
        validation_total = self._exact_match._count  # pylint: disable=protected-access
        all_metrics = {
            '_exact_match_count': validation_correct,
            '_example_count': validation_total,
            'exact_match': self._exact_match.get_metric(reset),
            'sql_validity': self._valid_sql_query.get_metric(reset=reset)['sql_validity'],
            'action_similarity': self._action_similarity.get_metric(reset)
        }
        all_metrics.update(self._token_match.get_metric(reset=reset))
        all_metrics.update(self._kb_match.get_metric(reset=reset))
        all_metrics.update(self._schema_free_match.get_metric(reset=reset))
        return all_metrics

    def _create_grammar_state(self,
                              possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of
        creating that is creating the `valid_actions` dictionary, which contains embedded
        representations of all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors
        we create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every non-terminal
        type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        """
        device = util.get_device_of(self._action_embedder.weight)
        # TODO(Mark): This type is pure \(- . ^)/
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}

        actions_grouped_by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list)
        for i, action in enumerate(possible_actions):
            if action.rule == "":
                continue
            if action.is_global_rule:
                actions_grouped_by_nonterminal[action.nonterminal].append((action, i))
            else:
                raise ValueError("The sql parser doesn't support non-global actions yet.")

        for key, production_rule_arrays in actions_grouped_by_nonterminal.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `production_rule_arrays` are
            # all the valid productions of that non-terminal.  We'll first split those
            # productions by global vs. linked action.
            global_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                global_actions.append((production_rule_array.rule_id, action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0).long()
                if device >= 0:
                    global_action_tensor = global_action_tensor.to(device)
                global_input_embeddings = self._action_embedder(global_action_tensor)
                global_output_embeddings = self._output_action_embedder(global_action_tensor)
                translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                           global_output_embeddings,
                                                           list(global_action_ids))

        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal,
                               reverse_productions=True)

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at
        test time, to finalize predictions.  This is (confusingly) a separate notion from the
        "decoder" in "encoder/decoder", where that decoder logic lives in
        ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the
        ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions,
                                                                          debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[batch_index][action], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['utterance_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict

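# --- Illustrative sketch, not part of the original model code: the alias normalization
# applied in ``forward`` above before computing the token-level SQL metrics.  The example
# query string here is hypothetical; the regex and replacement are the ones used above. ---
def _alias_normalization_sketch() -> None:
    hypothetical_sql = "SELECT CITY.NAME FROM TABLE_PLACEHOLDER AS CITY alias0 WHERE CITY.STATE = 'x'"
    normalized = re.sub(r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                        r" \g<1> AS \g<1>\g<2> ",
                        hypothetical_sql)
    # "TABLE_PLACEHOLDER AS CITY alias0" becomes "CITY AS CITYalias0", so predicted and
    # target token sequences can be compared with the concrete table names in place.
    assert normalized == "SELECT CITY.NAME FROM CITY AS CITYalias0 WHERE CITY.STATE = 'x'"
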
class DropSeq2Seq(Model):
    """
    Adaptation of the ``SimpleSeq2Seq`` class in allennlp_models, to support input spans in
    addition to input tokens.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies.  They may be under the same
        namespace (`tokens`) or the target tokens can have a different namespace, in which case
        it needs to be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences.
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify
        the target's namespace here.  If not, we'll assume it is "tokens", which is also the
        default choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side.  If not, we'll use the
        same value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each
        step of decoding, this is the function used to compute similarity between the decoder
        hidden state and encoder outputs.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search.  If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it
        is not less than this value, we use the ground truth labels for the whole batch.  Else,
        we use the predictions from the previous time step for the whole batch.  If this value
        is 0.0 (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds
        to not using target side ground truth labels.  See the following paper for more
        information: `Scheduled Sampling for Sequence Prediction with Recurrent Neural
        Networks. Bengio et al., 2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    span_extractor : ``SpanExtractor``, optional (default = None)
        If provided, extracts span representations from the encoded source, and those
        representations (instead of the token representations) are used as input to the
        decoder.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 schema_path: str = None,
                 attention: Attention = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.0,
                 dec_dropout: float = 0.0,
                 token_based_metric: Metric = None,
                 span_extractor: SpanExtractor = None,
                 sql_metrics: bool = True) -> None:
        super(DropSeq2Seq, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding,
        # and the end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,  # pylint: disable=protected-access
                                                   self._target_namespace)
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        if token_based_metric:
            self._token_based_metric = token_based_metric
        else:
            self._token_based_metric = TokenSequenceAccuracy()

        self._sql_metrics = schema_path is not None
        if self._sql_metrics:
            self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path)
            self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path)

        # At prediction time, we use a beam search to find the most likely sequence of target
        # tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder
        self._emb_dropout = Dropout(p=emb_dropout)
        self._dec_dropout = Dropout(p=dec_dropout)

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            self._attention = attention
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize
        # the hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

        # span extractor, allows using spans from the source as input to the decoder
        self._span_extractor = span_extractor

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step.  This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information needed to
            predict the next step, which includes the encoder outputs, the source mask, and the
            decoder hidden state and context.  Each of these tensors has shape
            ``(group_size, *)``, where ``*`` can be any other number of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a
            tensor of shape ``(group_size, num_classes)`` containing the predicted log
            probability of each class for the next step, for each item in the group, while
            ``updated_state`` is a dictionary of tensors containing the encoder outputs, source
            mask, and updated decoder hidden state and context.

        Notes
        -----
        We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to
        ``batch_size``, since the group may contain multiple states for each source sentence in
        the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward_on_instances(self, instances: List[Instance]) -> List[Dict[str, numpy.ndarray]]:
        """
        Takes a list of :class:`~allennlp.data.instance.Instance`s, converts that text into
        arrays using this model's :class:`Vocabulary`, passes those arrays through
        :func:`self.forward()` and :func:`self.decode()` (which by default does nothing) and
        returns the result.  Before returning the result, we convert any ``torch.Tensors`` into
        numpy arrays and separate the batched output into a list of individual dicts per
        instance.  Note that typically this will be faster on a GPU (and conditionally, on a
        CPU) than repeated calls to :func:`forward_on_instance`.

        Parameters
        ----------
        instances : List[Instance], required
            The instances to run the model on.

        Returns
        -------
        A list of the model's outputs for each instance.
        """
        batch_size = len(instances)
        with torch.no_grad():
            cuda_device = self._get_prediction_device()
            dataset = Batch(instances)
            dataset.index_instances(self.vocab)
            model_input = util.move_to_device(dataset.as_tensor_dict(), cuda_device)
            outputs = self.decode(self(**model_input))

            instance_separated_output: List[Dict[str, numpy.ndarray]] = [
                {} for _ in dataset.instances
            ]
            for name, output in list(outputs.items()):
                if isinstance(output, torch.Tensor):
                    # NOTE(markn): This is a hack because 0-dim pytorch tensors are not
                    # iterable.  This occurs with batch size 1, because we still want to
                    # include the loss in that case.
                    if output.dim() == 0:
                        output = output.unsqueeze(0)

                    if output.size(0) != batch_size:
                        self._maybe_warn_for_unseparable_batches(name)
                        continue
                    output = output.detach().cpu().numpy()
                elif len(output) != batch_size:
                    self._maybe_warn_for_unseparable_batches(name)
                    continue
                for instance_output, batch_element in zip(instance_separated_output, output):
                    instance_output[name] = batch_element

            for instance_output, instance_input in zip(instance_separated_output, instances):
                for field in instance_input.fields:
                    if field == 'spans' and 'source_tokens' in instance_input.fields:
                        spans = []
                        source_tokens = instance_input.fields['source_tokens'].tokens
                        for indexfield in instance_input.fields[field].field_list:
                            spans.append(
                                source_tokens[indexfield.span_start:indexfield.span_end + 1])
                        # Keep the recovered span token sequences for this instance.
                        instance_output[field] = spans
                    else:
                        instance_output[field] = instance_input.fields[field].tokens
            return instance_separated_output

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                spans: torch.IntTensor = None,
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make a forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
            The output of `TextField.as_array()` applied on the source `TextField`.  This will
            be passed through a `TextFieldEmbedder` and then through an encoder.
        spans : ``torch.IntTensor``
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and
            end indices of spans that could be informative.  Comes from a
            ``ListField[SpanField]`` of indices into the text of the input.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Output of `Textfield.as_array()` applied on target `TextField`.  We assume that the
            target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens, spans)

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during
            # training and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens:
                if self._bleu:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    self._bleu(best_predictions, target_tokens["tokens"])
                predicted_tokens = self.decode(output_dict)["predicted_tokens"]
                target_tokens_str = self.decode_target_tokens(target_tokens)
                if self._token_based_metric:
                    self._token_based_metric(predicted_tokens, target_tokens_str)
                if self._sql_metrics:
                    self._kb_match(predicted_tokens, target_tokens_str)
                    self._schema_free_match(predicted_tokens, target_tokens_str)

        return output_dict

    def decode_target_tokens(self, target_tokens):
        target_indices = target_tokens['tokens'].detach().cpu().numpy()
        target_tokens_output = []
        for i in range(target_indices.shape[0]):
            cur_target_indices = list(target_indices[i])
            if self._end_index in cur_target_indices:
                cur_target_indices = cur_target_indices[:cur_target_indices.index(self._end_index)]
            if self._start_index in cur_target_indices:
                cur_target_indices = cur_target_indices[cur_target_indices.index(self._start_index) + 1:]
            target_tokens_str = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                                 for x in cur_target_indices]
            target_tokens_output.append(target_tokens_str)
        return target_tokens_output

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at
        test time, to finalize predictions.  The logic for the decoder part of the
        encoder-decoder lives within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the
        ``output_dict``.
""" predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in indices ] all_predicted_tokens.append(predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def _encode(self, source_tokens: Dict[str, torch.Tensor], spans: torch.IntTensor = None) -> Dict[str, torch.Tensor]: outputs = {} # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_input = self._source_embedder(source_tokens) # shape: (batch_size, max_input_sequence_length) source_mask = util.get_text_field_mask(source_tokens) outputs["source_mask"] = source_mask # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = self._encoder(embedded_input, source_mask) encoder_outputs = self._emb_dropout(encoder_outputs) outputs["encoder_outputs"] = encoder_outputs # if spans (over the input) are given, return their representation instead of the # source tokens representation if spans is not None and self._span_extractor is not None: # Looking at the span start index is enough to know if # this is padding or not. Shape: (batch_size, num_spans) span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long() if span_mask.dim() == 1: span_mask = span_mask.unsqueeze(1) span_representations = self._span_extractor( encoder_outputs, spans, source_mask, span_mask) outputs["source_mask"] = span_mask outputs["encoder_outputs"] = span_representations return outputs def _init_decoder_state( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional()) # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = final_encoder_output # shape: (batch_size, decoder_output_dim) state["decoder_context"] = state["encoder_outputs"].new_zeros( batch_size, self._decoder_output_dim) return state def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. 
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        step_attention_input_weights: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, max_input_sequence_length)
            step_attention_input_weights.append(state['input_weights'].unsqueeze(1))

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)
        # shape: (batch_size, num_decoding_steps, max_input_sequence_length)
        attention_input_weights = torch.cat(step_attention_input_weights, 1)

        output_dict = {
            "predictions": predictions,
            "attention_input_weights": attention_input_weights
        }

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,),
                                                          fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]
                                    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections into the target space,
        which can then be used to get probabilities of each target token for the next step.

        Inputs are the same as for `take_step()`.
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder(last_predictions) if self._attention: # shape: (group_size, encoder_output_dim) attended_input, input_weights = self._prepare_attended_input( decoder_hidden, encoder_outputs, source_mask) state["input_weights"] = input_weights # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat((attended_input, embedded_input), -1) else: # shape: (group_size, target_embedding_dim) decoder_input = embedded_input decoder_input = self._dec_dropout(decoder_input) # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) output_projections = self._output_projection_layer( self._dec_dropout(decoder_hidden)) return output_projections, state def _prepare_attended_input( self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor: """Apply attention over encoder outputs and decoder state.""" # Ensure mask is also a FloatTensor. Or else the multiplication within # attention will complain. # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs_mask = encoder_outputs_mask.float() # shape: (batch_size, max_input_sequence_length) input_weights = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask) # shape: (batch_size, encoder_output_dim) attended_input = util.weighted_sum(encoder_outputs, input_weights) return attended_input, input_weights @staticmethod def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor) -> torch.Tensor: """ Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross entropy loss while taking the mask into account. The length of ``targets`` is expected to be greater than that of ``logits`` because the decoder does not need to compute the output corresponding to the last timestep of ``targets``. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. 
        The complete sequence would correspond to ``<S> w1 w2 w3 <E> <P> <P>``, the mask would
        be ``1 1 1 1 1 0 0``, and let the logits be ``l1 l2 l3 l4 l5 l6``.  We actually need to
        compare::

            the sequence           w1 w2 w3 <E> <P> <P>
            with masks             1  1  1  1   0   0
            against                l1 l2 l3 l4  l5  l6
            (where the input was)  <S> w1 w2 w3 <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._bleu:
                all_metrics.update(self._bleu.get_metric(reset=reset))
            all_metrics.update(self._token_based_metric.get_metric(reset=reset))
            if self._sql_metrics:
                all_metrics.update(self._kb_match.get_metric(reset=reset))
                all_metrics.update(self._schema_free_match.get_metric(reset=reset))
        return all_metrics
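
# --- Illustrative sketch, not part of the original model code: the shift-by-one alignment
# described in the ``_get_loss`` docstring above, shown on a toy target sequence.  The index
# values below are arbitrary stand-ins for <P>=0, <S>=1, <E>=2. ---
def _loss_alignment_sketch() -> None:
    # targets: <S> w1 w2 w3 <E> <P> <P>
    targets = torch.tensor([[1, 10, 11, 12, 2, 0, 0]])
    target_mask = (targets != 0).long()

    # The decoder is fed <S> w1 w2 w3 <E> <P>, so it produces
    # ``targets.size(1) - 1`` steps of logits ...
    num_decoding_steps = targets.size(1) - 1

    # ... which are compared against the targets shifted left by one position.
    relevant_targets = targets[:, 1:]       # w1 w2 w3 <E> <P> <P>
    relevant_mask = target_mask[:, 1:]      # 1  1  1  1   0   0

    assert relevant_targets.size(1) == num_decoding_steps
    assert relevant_mask.tolist() == [[1, 1, 1, 1, 0, 0]]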