Example #1
class STSBTask(Task):
    ''' Task class for the Semantic Textual Similarity Benchmark (STS-B). '''
    def __init__(self, path, max_seq_len, name="sts_benchmark"):
        ''' '''
        super(STSBTask, self).__init__(name, 1)
        self.categorical = 0
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer1 = Average()
        self.scorer2 = Average()
        self.load_data(path, max_seq_len)

    def load_data(self, path, max_seq_len):
        ''' '''
        tr_data = load_tsv(os.path.join(path, 'train.tsv'), max_seq_len, skip_rows=1,
                           s1_idx=7, s2_idx=8, targ_idx=9, targ_fn=lambda x: float(x) / 5)
        val_data = load_tsv(os.path.join(path, 'dev.tsv'), max_seq_len, skip_rows=1,
                            s1_idx=7, s2_idx=8, targ_idx=9, targ_fn=lambda x: float(x) / 5)
        te_data = load_tsv(os.path.join(path, 'test.tsv'), max_seq_len,
                           s1_idx=7, s2_idx=8, targ_idx=None, idx_idx=0, skip_rows=1)
        self.train_data_text = tr_data
        self.val_data_text = val_data
        self.test_data_text = te_data
        log.info("\tFinished loading STS Benchmark data.")

    def get_metrics(self, reset=False):
        # NB: reported under 'accuracy' to work around a quirk in the training loop.
        return {'accuracy': self.scorer1.get_metric(reset),
                'spearmanr': self.scorer2.get_metric(reset)}
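
The two Average scorers are presumably fed per-batch correlation values during evaluation; a minimal sketch of that update pattern (the update_scorers helper and the scipy calls are illustrative assumptions, not part of the original task code):

from scipy.stats import pearsonr, spearmanr

def update_scorers(task, predictions, golds):
    # Hypothetical helper: scorer1 is assumed to track Pearson correlation
    # (reported as 'accuracy') and scorer2 Spearman correlation.
    task.scorer1(pearsonr(predictions, golds)[0])
    task.scorer2(spearmanr(predictions, golds)[0])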
Example #2
class CoLATask(Task):
    '''Class for the Warstadt acceptability task (CoLA)'''
    def __init__(self, path, max_seq_len, name="acceptability"):
        ''' '''
        super(CoLATask, self).__init__(name, 2)
        self.pair_input = 0
        self.load_data(path, max_seq_len)
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer1 = Average()
        self.scorer2 = CategoricalAccuracy()

    def load_data(self, path, max_seq_len):
        '''Load the data'''
        tr_data = load_tsv(os.path.join(path, "train.tsv"), max_seq_len,
                           s1_idx=3, s2_idx=None, targ_idx=1)
        val_data = load_tsv(os.path.join(path, "dev.tsv"), max_seq_len,
                            s1_idx=3, s2_idx=None, targ_idx=1)
        te_data = load_tsv(os.path.join(path, 'test.tsv'), max_seq_len,
                           s1_idx=1, s2_idx=None, targ_idx=None, idx_idx=0, skip_rows=1)
        self.train_data_text = tr_data
        self.val_data_text = val_data
        self.test_data_text = te_data
        log.info("\tFinished loading CoLA.")

    def get_metrics(self, reset=False):
        # NB: reported under 'accuracy' to work around a quirk in the training loop.
        return {'accuracy': self.scorer1.get_metric(reset),
                'acc': self.scorer2.get_metric(reset)}
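
CoLA is conventionally scored with Matthews correlation; one plausible way to feed these scorers, mirroring the matthews_corrcoef pattern in Example #3 below (the helper itself is hypothetical):

from sklearn.metrics import matthews_corrcoef

def update_cola_scorers(task, logits, labels):
    # Hypothetical helper: average per-batch Matthews correlation in scorer1,
    # and let CategoricalAccuracy consume raw logits plus gold labels.
    preds = logits.argmax(dim=1)
    task.scorer1(matthews_corrcoef(labels.cpu().numpy(), preds.cpu().numpy()))
    task.scorer2(logits, labels)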
Example #3
class Seq2SeqClassifier(Model):
  def __init__(self, word_embeddings: TextFieldEmbedder, encoder: Seq2SeqEncoder,
               vocab: Vocabulary, hidden_dimension: int, bs: int) -> None:
    super().__init__(vocab)
    self.word_embeddings = word_embeddings
    self.encoder = encoder
    self.bs = bs
    self.hidden_dim = hidden_dimension
    self.vocab = vocab
    self.tasks_vocabulary = {"default": vocab}
    self.current_task = "default"
    self.num_task = 0
    self.classification_layers = torch.nn.ModuleList([
        torch.nn.Linear(in_features=self.hidden_dim,
                        out_features=self.vocab.get_vocab_size('labels'))])
    self.task2id = { "default": 0 }
    self.hidden2tag = self.classification_layers[self.task2id["default"]]
    self.accuracy = CategoricalAccuracy()
    self.loss_function = torch.nn.CrossEntropyLoss()
    self.average = Average()
    self.activations = []
    self.labels = []

  def add_task(self, task_tag: str, vocab: Vocabulary):
    self.classification_layers.append(
        torch.nn.Linear(in_features=self.hidden_dim,
                        out_features=vocab.get_vocab_size('labels')))
    self.num_task = self.num_task + 1
    self.task2id[task_tag] = self.num_task
    self.tasks_vocabulary[task_tag] = vocab

  def set_task(self, task_tag: str):
    self.hidden2tag = self.classification_layers[self.task2id[task_tag]]
    self.current_task = task_tag
    self.vocab = self.tasks_vocabulary[task_tag]

  def forward(self, tokens: Dict[str, torch.Tensor], label: torch.Tensor = None) -> Dict[str, torch.Tensor]:

    mask = get_text_field_mask(tokens)
    embeddings = self.word_embeddings(tokens)
    encoder_out = self.encoder(embeddings, mask)
    pooled = torch.nn.functional.adaptive_max_pool1d(
        encoder_out.permute(0, 2, 1), (1,)).view(-1, self.hidden_dim)
    tag_logits = self.hidden2tag(pooled)
    output = {'logits': tag_logits }
    self.activations = encoder_out
    self.labels = label

    if label is not None:
      _, preds = tag_logits.max(dim=1)
      self.average(matthews_corrcoef(label.data.cpu().numpy(), preds.data.cpu().numpy()))
      self.accuracy(tag_logits, label)
      output["loss"] = self.loss_function(tag_logits, label)

    return output

  def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    return {"accuracy": self.accuracy.get_metric(reset), "average": self.average.get_metric(reset)}

  def get_activations(self) -> Tuple[torch.Tensor, torch.Tensor]:
    return self.activations, self.labels
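
A sketch of the intended multi-task flow based on the methods above; the embedder, encoder, vocabularies, and batch variables are placeholders:

# Hypothetical usage: register a second task head, then switch to it before
# running batches from that task.
model = Seq2SeqClassifier(word_embeddings, encoder, default_vocab,
                          hidden_dimension=256, bs=32)
model.add_task("cola", cola_vocab)
model.set_task("cola")  # routes hidden2tag to the CoLA head
output = model(tokens_batch, label_batch)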
Example #4
def perplexity(lm_path, sample_file):
    import kenlm
    model = kenlm.LanguageModel(lm_path)
    ppl = Average()
    num_lines = file_utils.get_num_lines(sample_file)
    with open(sample_file) as sf:
        with tqdm(sf, total=num_lines, desc='Computing PPL') as pbar:
            for sentence in pbar:
                ppl(model.perplexity(sentence))
                pbar.set_postfix({'PPL': ppl.get_metric()})

    logger.info(f'PPL for file {sample_file} = {ppl.get_metric()}')
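
Typical invocation, assuming a KenLM model file (ARPA or binarized) and a plain-text sample file with one sentence per line; both paths are placeholders:

perplexity('wiki.arpa', 'samples.txt')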
Example #5
class HatefulMemeModel(Model):
    def __init__(self, vocab: Vocabulary, text_model_name: str):
        super().__init__(vocab)
        self._text_model = BertForSequenceClassification.from_pretrained(
            text_model_name)
        self._num_labels = vocab.get_vocab_size()

        self._accuracy = Average()
        self._auc = Auc()

        self._softmax = torch.nn.Softmax(dim=1)

    def forward(
        self,
        source_tokens: TextFieldTensors,
        box_features: Optional[Tensor] = None,
        box_coordinates: Optional[Tensor] = None,
        box_mask: Optional[Tensor] = None,
        label: Optional[Tensor] = None,
        metadata: Optional[Dict] = None,
    ) -> Dict[str, torch.Tensor]:
        input_ids = source_tokens["tokens"]["token_ids"]
        input_mask = source_tokens["tokens"]["mask"]
        token_type_ids = source_tokens["tokens"]["type_ids"]
        outputs = self._text_model(
            input_ids=input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            labels=label,
        )

        if label is not None:
            probs = self._softmax(outputs.logits)
            predictions = torch.argmax(probs, dim=-1)
            for index in range(predictions.shape[0]):
                correct = float(predictions[index] == label[index])
                self._accuracy(correct)

            # Auc expects real-valued scores, so pass the positive-class
            # probability rather than the hard argmax predictions.
            self._auc(probs[:, 1], label)

        return outputs

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics["accuracy"] = self._accuracy.get_metric(reset=reset)
            metrics["auc"] = self._auc.get_metric(reset=reset)
        return metrics
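
A construction sketch for the model above; the checkpoint name and batch variables are placeholders, and the vocab is assumed to hold the two meme labels:

model = HatefulMemeModel(vocab, text_model_name='bert-base-uncased')
outputs = model(source_tokens=batch['source_tokens'], label=batch['label'])
print(model.get_metrics())  # only populated outside training mode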
Example #6
class NamesClassifier(Model):
    """
    NamesClassifier takes in a name and a label.
    It performs the forward pass and updates metrics such as loss and accuracy.
    """
    def __init__(self, char_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder, vocab: Vocabulary) -> None:
        super().__init__(vocab)
        # Initialize embedding vector.
        self.char_embeddings = char_embeddings
        # Initialize the encoder.
        self.encoder = encoder
        # Initialize the hidden-to-tag layer.
        # It outputs a score for each label.
        self.hidden2tag = torch.nn.Linear(
            in_features=encoder.get_output_dim(),
            out_features=vocab.get_vocab_size('labels'))
        # Initialize the average metric.
        self.accuracy = Average()
        # LogSoftmax is faster and has better numerical properties than Softmax.
        self.m = LogSoftmax(dim=-1)
        # The negative log likelihood loss. It is useful to train a
        # classification problem with `C` classes
        self.loss = NLLLoss()

    @overrides
    def forward(self,
                name: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        # Create a mask to ignore padding indices.
        mask = get_text_field_mask(name)
        # Create embeddings for the name.
        embeddings = self.char_embeddings(name)
        # Encode the embeddings with mask
        encoder_out = self.encoder(embeddings, mask)
        # Calculate the logit scores
        tag_logits = self.hidden2tag(encoder_out)
        # Update the metrics and return output
        output = {"tag_logits": tag_logits}
        if label is not None:
            output["loss"] = self.loss(self.m(tag_logits), label)
            prediction = tag_logits.max(1)[1]
            self.accuracy(prediction.eq(label).double().mean())
        return output

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        # Simply return accuracy after each pass
        return {"accuracy": float(self.accuracy.get_metric(reset))}
Example #7
class MajorityClassifier(Model):
  def __init__(self, vocab: Vocabulary) -> None:
    super().__init__(vocab)
    self.vocab = vocab
    self.current_task = "default"
    self.tasks_vocabulary = {"default": vocab}
    self.classification_layers = torch.nn.ModuleList([torch.nn.Linear(in_features=10, out_features=self.vocab.get_vocab_size('labels'))])
    self.loss_function = torch.nn.CrossEntropyLoss()
    self.accuracy = CategoricalAccuracy()
    self.average  = Average()

  def add_task(self, task_tag: str, vocab: Vocabulary):
    self.tasks_vocabulary[task_tag] = vocab

  def set_task(self, task_tag: str):
    self.current_task = task_tag
    self.vocab = self.tasks_vocabulary[task_tag]

  def forward(self, tokens: Dict[str, torch.Tensor], label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
    # Emit a constant one-hot "logit" vector for the majority class of the
    # current task, repeated across the batch.
    logi = np.zeros(2)
    if self.current_task == "cola":
      logi = np.zeros(2)
      logi[self.vocab.get_token_index("1", "labels")] = 1
      logi = torch.Tensor(np.repeat([logi], label.size(0), 0))
    elif self.current_task == "trec":
      logi = np.zeros(6)
      logi[self.vocab.get_token_index("ENTY", "labels")] = 1
      logi = torch.Tensor(np.repeat([logi], label.size(0), 0))
    elif self.current_task == "sst":
      logi = np.zeros(5)
      logi[self.vocab.get_token_index("3", "labels")] = 1
      logi = torch.Tensor(np.repeat([logi], label.size(0), 0))
    elif self.current_task == "subjectivity":
      logi = np.zeros(2)
      logi[self.vocab.get_token_index("SUBJECTIVE", "labels")] = 1
      logi = torch.Tensor(np.repeat([logi], label.size(0), 0))
    output = {}
    if label is not None:
      _, preds = logi.max(dim=1)
      self.average(matthews_corrcoef(label.data.cpu().numpy(), preds.data.cpu().numpy()))
      self.accuracy(logi, label)
      output["loss"] = torch.tensor([0.0])
    return output

  def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    return {"accuracy": self.accuracy.get_metric(reset), "average": self.average.get_metric(reset)}
Example #8
def perplexity(lm_path, pickle_path):
    import kenlm
    model = kenlm.LanguageModel(lm_path)
    ppl = Average()
    with open(pickle_path, 'rb') as sf:
        dialog_dict = pickle.load(sf)
        # Assumes the dict holds two bookkeeping entries in addition to
        # the response_1 .. response_N lists.
        num_responses = len(dialog_dict) - 2
        responses_list = []
        for rno in range(num_responses):
            responses_list.append(list(dialog_dict[f'response_{rno + 1}']))
        flattened_responses = [response for responses in responses_list for response in responses]
        with tqdm(flattened_responses, desc='Computing PPL') as pbar:
            for sentence in pbar:
                ppl(model.perplexity(sentence))
                pbar.set_postfix({'PPL': ppl.get_metric()})

    logger.info(f'PPL for file {pickle_path} = {ppl.get_metric()}')
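
The function assumes a pickled dict with two bookkeeping entries plus response_1 .. response_N lists; a toy file in that shape (the non-response key names are assumptions):

import pickle

dialog_dict = {
    'contexts': ['hi', 'how are you'],   # assumed non-response key
    'references': ['hello', 'fine'],     # assumed non-response key
    'response_1': ['hi there', 'hello'],
    'response_2': ['doing well', 'fine thanks'],
}
with open('dialogs.pkl', 'wb') as f:
    pickle.dump(dialog_dict, f)
perplexity('wiki.arpa', 'dialogs.pkl')  # model path is a placeholder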
Example #9
class Decoder(torch.nn.Module, Registrable):
    """``Decoder`` class is a wrapper for different decoders."""
    def __init__(self, vocab: Vocabulary):
        super().__init__()  # type: ignore
        self.vocab = vocab
        self._nll = Average()
        self._ppl = WordPPL()
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        self._pad_index = self.vocab.get_token_index(
            self.vocab._padding_token)  # noqa: WPS437
        self._bleu = BLEU(exclude_indices={
            self._pad_index, self._end_index, self._start_index
        })

    def forward(
        self,
        encoder_outs: Dict[str, Any],
        target_tokens: Dict[str, torch.LongTensor],
    ) -> Dict[str, torch.Tensor]:
        """Run the module's forward function."""
        raise NotImplementedError

    def post_process(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Post process after decoding."""
        raise NotImplementedError

    def get_metrics(self, reset: bool = False):
        """Collect all available metrics."""
        all_metrics: Dict[str, float] = {}
        if not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        all_metrics.update({'nll': float(self._nll.get_metric(reset=reset))})
        all_metrics.update({'_ppl': float(self._ppl.get_metric(reset=reset))})
        return all_metrics
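
Because Decoder mixes in Registrable, concrete decoders are registered by name and chosen from configuration; a minimal subclass skeleton under that assumption (the name 'greedy' and the bodies are illustrative only):

@Decoder.register('greedy')
class GreedyDecoder(Decoder):
    def forward(self, encoder_outs, target_tokens):
        # Score target_tokens against encoder_outs, then update
        # self._nll / self._ppl / self._bleu before returning an output dict.
        ...

    def post_process(self, output_dict):
        # Map predicted token ids back to strings via self.vocab.
        return output_dict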
Example #10
class MyModel(Model):
    def __init__(
        self,
        text_field_embedder: TextFieldEmbedder,
        vocab: Vocabulary,
        sym_size: int,
        seq2vec_encoder: Seq2VecEncoder = None,
        dropout: float = None,
        regularizer: RegularizerApplicator = None,
    ):
        super().__init__(vocab, regularizer)

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = None
        self.sym_size = sym_size
        self.embeddings = text_field_embedder
        self.vec_encoder = seq2vec_encoder
        self.hidden_dim = self.vec_encoder.get_output_dim()
        self.linear_class = torch.nn.Linear(self.hidden_dim, self.sym_size)
        # self.f_linear = torch.nn.Linear(self.hidden_dim * 2, self.hidden_dim * 2)
        self.dim = [12, 62, 4, 40, 62]
        self.true_list = [Average() for i in range(5)]
        self.pre_total = [Average() for i in range(5)]
        self.pre_true = [Average() for i in range(5)]
        self.total_pre = Average()
        self.total_true = Average()
        self.total_pre_true = Average()
        self.total_future_true = Average()
        self.macro_f = MacroF(self.sym_size)
        self.turn_acc = Average()

        self.future_acc = Average()

    def forward(self, text, label, **args):
        bs = label.size(0)
        embeddings = self.embeddings(text)  # bs  * seq_len * embedding

        mask = get_text_field_mask(text)  # bs * sen_num * sen_len

        seq_hidden = self.vec_encoder(embeddings, mask)  # bs , embedding
        # Shape: (batch_size, num_labels)
        topic_probs = torch.sigmoid(self.linear_class(seq_hidden))
        # topic_weight = torch.ones_like(label) + 2 * label
        topic_weight = torch.ones_like(label) + label * 4
        loss = F.binary_cross_entropy(topic_probs, label.float(),
                                      topic_weight.float())
        # loss = F.binary_cross_entropy(topic_probs, label.long(), topic_weight.long())
        output_dict = {
            'loss': loss,
            'probs': topic_probs,
            'last_hidden': seq_hidden
        }
        # _, max_index = torch.max(topic_probs, -1)
        total_pre_list = []
        total_true_list = []
        total_pre_true_list = []
        pre_index = (topic_probs > 0.5).long()
        # pre_index = (topic_probs > 0.5).float()
        total_pre = torch.sum(pre_index)
        total_true = torch.sum(label)
        mask_index = (label == 1).long()
        # mask_index = (label == 1).float()
        self.macro_f(pre_index.cpu(), label.cpu())
        true_positive = (pre_index == label).long() * mask_index
        # true_positive = (pre_index == label).float() * mask_index
        st = 0
        for i in range(5):
            total_pre_list.append(torch.sum(pre_index[:, st:st + self.dim[i]]))
            total_true_list.append(torch.sum(label[:, st:st + self.dim[i]]))
            total_pre_true_list.append(
                torch.sum(true_positive[:, st:st + self.dim[i]]))
            st += self.dim[i]

        turn_true_num = (torch.sum(true_positive,
                                   1) == torch.sum(mask_index, 1)).long()
        # turn_true_num = (torch.sum(true_positive, 1) == torch.sum(mask_index, 1)).float()
        self.turn_acc(torch.sum(turn_true_num).item() / bs)
        pre_true = torch.sum(true_positive)

        self.total_pre(total_pre.float().item())
        self.total_true(total_true.float().item())
        self.total_pre_true(pre_true.float().item())
        self.total_future_true(
            torch.sum((pre_index == args['future']).long() *
                      (args['future'] == 1).long()).item())
        # self.total_future_true(torch.sum((pre_index == args['future']) * (args['future'] == 1).float()).item())
        for i in range(5):
            self.pre_total[i](total_pre_list[i].float().item())
            self.pre_true[i](total_pre_true_list[i].float().item())
            self.true_list[i](total_true_list[i].float().item())

        return output_dict

    def get_metrics(self, reset=False):
        metrics = {}
        total_pre = self.total_pre.get_metric(reset=reset)
        total_pre_true = self.total_pre_true.get_metric(reset=reset)
        total_true = self.total_true.get_metric(reset=reset)
        total_future_true = self.total_future_true.get_metric(reset=reset)
        for i in range(5):
            pre_i = self.pre_total[i].get_metric(reset=reset)
            pre_true_i = self.pre_true[i].get_metric(reset=reset)
            true_i = self.true_list[i].get_metric(reset=reset)
            acc_i, rec_i, f_i = 0., 0., 0.
            if pre_i > 0:
                acc_i = pre_true_i / pre_i
            if true_i > 0:
                rec_i = pre_true_i / true_i
            if acc_i + rec_i > 0:
                f_i = 2 * acc_i * rec_i / (acc_i + rec_i)
            metrics['f1' + str(i)] = f_i
            metrics['rc' + str(i)] = rec_i
            metrics['ac' + str(i)] = acc_i
        acc, rec, f1, facc = 0., 0., 0., 0.
        if total_pre > 0:
            acc = total_pre_true / total_pre
            facc = total_future_true / total_pre
        if total_true > 0:
            rec = total_pre_true / total_true
        if acc + rec > 0:
            f1 = 2 * acc * rec / (acc + rec)

        metrics['acc'] = acc
        metrics['rec'] = rec
        metrics['f1'] = f1
        metrics['macro_f1'] = self.macro_f.get_metric(reset=reset)
        metrics['turn_acc'] = self.turn_acc.get_metric(reset=reset)
        metrics['future_acc'] = facc
        return metrics
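
Note that what get_metrics reports as 'acc' is micro-precision (true positives over predicted positives) and 'rec' is micro-recall; 'f1' is their harmonic mean. A quick numeric check of the same formulas, with invented counts:

pre_true, total_pre, total_true = 30.0, 40.0, 60.0
precision = pre_true / total_pre                     # 0.75
recall = pre_true / total_true                       # 0.50
f1 = 2 * precision * recall / (precision + recall)   # 0.60
assert abs(f1 - 0.6) < 1e-9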
Example #11
class WikiTablesSemanticParser(Model):
    """
    A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a question,
    and produces a logical form that answers the question when executed over the table.  The
    logical form is generated by a `type-constrained`, `transition-based` parser. This is an
    abstract class that defines most of the functionality related to the transition-based parser. It
    does not contain the implementation for actually training the parser. You may want to train it
    using a learning-to-search algorithm, in which case you will want to use
    ``WikiTablesErmSemanticParser``, or if you have a set of approximate logical forms that give the
    correct denotation, you will want to use ``WikiTablesMmlSemanticParser``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables.
    """
    # pylint: disable=abstract-method
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._max_decoding_steps = max_decoding_steps
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(tables_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(),
                               "entity word average embedding dim", "question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        if self._use_neighbor_similarity_for_linking:
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)
        else:
            self._question_entity_params = None
            self._question_neighbor_params = None

    def _get_initial_state_and_scores(self,
                                      question: Dict[str, torch.LongTensor],
                                      table: Dict[str, torch.LongTensor],
                                      world: List[WikiTablesWorld],
                                      actions: List[List[ProductionRuleArray]],
                                      example_lisp_string: List[str] = None,
                                      add_world_to_initial_state: bool = False,
                                      checklist_states: List[ChecklistState] = None) -> Dict:
        """
        Does initial preparation and creates an initial state for both semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               output_action_embeddings=output_action_embeddings,
                                               action_biases=action_biases,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               world=initial_state_world,
                                               example_lisp_string=example_lisp_string,
                                               checklist_state=checklist_states,
                                               debug_info=None)
        return {"initial_state": initial_state,
                "linking_scores": linking_scores,
                "feature_scores": feature_scores,
                "similarity_scores": question_entity_similarity_max_score}

    @staticmethod
    def _get_neighbor_indices(worlds: List[WikiTablesWorld],
                              num_entities: int,
                              tensor: torch.Tensor) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors. A tensor
        is accepted as a parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index.
        """

        num_neighbors = 0
        for world in worlds:
            for entity in world.table_graph.entities:
                if len(world.table_graph.neighbors[entity]) > num_neighbors:
                    num_neighbors = len(world.table_graph.neighbors[entity])

        batch_neighbors = []
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.table_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.table_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]]
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(neighbor_indexes,
                                                      num_entities,
                                                      lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        return tensor.new_tensor(batch_neighbors, dtype=torch.long)

    @staticmethod
    def _get_type_vector(worlds: List[WikiTablesWorld],
                         num_entities: int,
                         tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the one hot encoding for each entity's type. In addition,
        a map from a flattened entity index to type is returned to combine
        entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                one_hot_vectors = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(one_hot_vectors[entity_type])

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: [0, 0, 0, 0])
            batch_types.append(padded)
        return tensor.new_tensor(batch_types), entity_types

    def _get_linking_probabilities(self,
                                   worlds: List[WikiTablesWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(num_question_tokens,
                                                 num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()
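
_action_history_match returns 1 exactly when the predicted sequence is a prefix of at least one target row; a small illustration with invented tensors:

import torch

targets = torch.tensor([[3, 5, 7, 0], [3, 5, 9, 0]])
# [3, 5, 9] matches the prefix of the second row, so the result is 1.
assert WikiTablesSemanticParser._action_history_match([3, 5, 9], targets) == 1
assert WikiTablesSemanticParser._action_history_match([3, 6], targets) == 0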

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. dpd_acc, which is the percentage of the time that our best output action sequence is
            in the set of action sequences provided by DPD.  This is an easy-to-compute lower bound
            on denotation accuracy for the set of examples where we actually have DPD output.  We
            only score dpd_acc on that subset.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that has DPD output (make sure
            you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        return {
                'dpd_acc': self._action_sequence_accuracy.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'lf_percent': self._has_logical_form.get_metric(reset),
                }

    @staticmethod
    def _create_grammar_state(world: WikiTablesWorld,
                              possible_actions: List[ProductionRuleArray]) -> GrammarState:
        valid_actions = world.get_valid_actions()
        action_mapping = {}
        for i, action in enumerate(possible_actions):
            action_string = action[0]
            action_mapping[action_string] = i
        translated_valid_actions = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = [action_mapping[action_string]
                                             for action_string in action_strings]
        return GrammarState([START_SYMBOL],
                            {},
                            translated_valid_actions,
                            action_mapping,
                            type_declaration.is_nonterminal)

    def _embed_actions(self, actions: List[List[ProductionRuleArray]]) -> Tuple[torch.Tensor,
                                                                                torch.Tensor,
                                                                                torch.Tensor,
                                                                                Dict[Tuple[int, int], int]]:
        """
        Given all of the possible actions for all batch instances, produce an embedding for them.
        There will be significant overlap in this list, as the production rules from the grammar
        are shared across all batch instances.  Our returned tensor has an embedding for each
        `unique` action, so we also need to return a mapping from the original ``(batch_index,
        action_index)`` to our new ``global_action_index``, so that we can get the right action
        embedding during decoding.

        Returns
        -------
        action_embeddings : ``torch.Tensor``
            Has shape ``(num_unique_actions, action_embedding_dim)``.
        output_action_embeddings : ``torch.Tensor``
            Has shape ``(num_unique_actions, action_embedding_dim)``.
        action_biases : ``torch.Tensor``
            Has shape ``(num_unique_actions, 1)``.
        action_map : ``Dict[Tuple[int, int], int]``
            Maps ``(batch_index, action_index)`` in the input action list to ``action_index`` in
            the ``action_embeddings`` tensor.  All non-embeddable actions get mapped to `-1` here.
        """
        # TODO(mattg): This whole action pipeline might be a whole lot more complicated than it
        # needs to be.  We used to embed actions differently (using some crazy ideas about
        # embedding the LHS and RHS separately); we could probably get away with simplifying things
        # further now that we're just doing a simple embedding for global actions.  But I'm leaving
        # it like this for now to have a minimal change to go from the LHS/RHS embedding to a
        # single action embedding.
        embedded_actions = self._action_embedder.weight
        output_embedded_actions = self._output_action_embedder.weight
        action_biases = self._action_biases.weight

        # Now we just need to make a map from `(batch_index, action_index)` to
        # `global_action_index`.  global_action_ids has the list of all unique actions; here we're
        # going over all of the actions for each batch instance so we can map them to the global
        # action ids.
        action_vocab = self.vocab.get_token_to_index_vocabulary(self._rule_namespace)
        action_map: Dict[Tuple[int, int], int] = {}
        for batch_index, instance_actions in enumerate(actions):
            for action_index, action in enumerate(instance_actions):
                if not action[0]:
                    # This rule is padding.
                    continue
                global_action_id = action_vocab.get(action[0], -1)
                action_map[(batch_index, action_index)] = global_action_id
        return embedded_actions, output_embedded_actions, action_biases, action_map

    @staticmethod
    def _map_entity_productions(linking_scores: torch.FloatTensor,
                                worlds: List[WikiTablesWorld],
                                actions: List[List[ProductionRuleArray]]) -> Tuple[torch.Tensor,
                                                                                   Dict[Tuple[int, int], int]]:
        """
        Constructs a map from ``(batch_index, action_index)`` to ``(batch_index * entity_index)``.
        That is, some actions correspond to terminal productions of entities from our table.  We
        need to find those actions and map them to their corresponding entity indices, where the
        entity index is its position in the list of entities returned by the ``world``.  This list
        is what defines the second dimension of the ``linking_scores`` tensor, so we can use this
        index to look up linking scores for each action in that tensor.

        For easier processing later, the mapping that we return is `flattened` - we really want to
        map ``(batch_index, action_index)`` to ``(batch_index, entity_index)``, but we are going to
        have to use the result of this mapping to do ``index_selects`` on the ``linking_scores``
        tensor.  You can't do ``index_select`` with tuples, so we flatten ``linking_scores`` to
        have shape ``(batch_size * num_entities, num_question_tokens)``, and return shifted indices
        into this flattened tensor.

        Parameters
        ----------
        linking_scores : ``torch.Tensor``
            A tensor representing linking scores between each table entity and each question token.
            Has shape ``(batch_size, num_entities, num_question_tokens)``.
        worlds : ``List[WikiTablesWorld]``
            The ``World`` for each batch instance.  The ``World`` contains a reference to the
            ``TableKnowledgeGraph`` that defines the set of entities in the linking.
        actions : ``List[List[ProductionRuleArray]]``
            The list of possible actions for each batch instance.  Our action indices are defined
            in terms of this list, so we'll find entity productions in this list and map them to
            entity indices from the entity list we get from the ``World``.

        Returns
        -------
        flattened_linking_scores : ``torch.Tensor``
            A flattened version of ``linking_scores``, with shape ``(batch_size * num_entities,
            num_question_tokens)``.
        actions_to_entities : ``Dict[Tuple[int, int], int]``
            A mapping from ``(batch_index, action_index)`` to ``batch_index * num_entities +
            entity_index``, representing which rows of the returned ``flattened_linking_scores``
            tensor correspond to which actions.
        """
        batch_size, num_entities, num_question_tokens = linking_scores.size()
        entity_map: Dict[Tuple[int, str], int] = {}
        for batch_index, world in enumerate(worlds):
            for entity_index, entity in enumerate(world.table_graph.entities):
                entity_map[(batch_index, entity)] = batch_index * num_entities + entity_index
        actions_to_entities: Dict[Tuple[int, int], int] = {}
        for batch_index, action_list in enumerate(actions):
            for action_index, action in enumerate(action_list):
                if not action[0]:
                    # This action is padding.
                    continue
                _, production = action[0].split(' -> ')
                entity_index = entity_map.get((batch_index, production), None)
                if entity_index is not None:
                    actions_to_entities[(batch_index, action_index)] = entity_index
        flattened_linking_scores = linking_scores.view(batch_size * num_entities, num_question_tokens)
        return flattened_linking_scores, actions_to_entities
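
    # A minimal sketch (toy tensors only) of the flattening trick used above: a
    # (batch_index, entity_index) pair becomes the single row index
    # ``batch_index * num_entities + entity_index`` in the flattened scores, which is the
    # form ``index_select`` needs:
    #
    #     import torch
    #     batch_size, num_entities, num_question_tokens = 2, 3, 4
    #     linking_scores = torch.randn(batch_size, num_entities, num_question_tokens)
    #     flattened = linking_scores.view(batch_size * num_entities, num_question_tokens)
    #     # Entity 2 of batch instance 1 lives at row 1 * 3 + 2 == 5:
    #     assert torch.equal(flattened[5], linking_scores[1, 2])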

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``WikiTablesDecoderStep``.

        This method resolves action indices into production rule strings and adds a field called
        ``predicted_actions`` to the ``output_dict``: for each decoding step it records the
        predicted action, the considered actions with their probabilities, and the question
        attention.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
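
# A minimal, self-contained sketch of the per-step post-processing performed in ``decode``
# above, using toy values (the names below are illustrative, not part of the model's API):
# padded action slots are marked with -1, and surviving actions are resolved through the
# action mapping and paired with their probabilities.
def _summarize_step(action_mapping, batch_index, considered_actions, probabilities):
    """Return (action_string, probability) pairs for one decoding step, sorted by rule."""
    pairs = [(action_mapping[(batch_index, action)], probability)
             for action, probability in zip(considered_actions, probabilities)
             if action != -1]
    pairs.sort()
    return pairs

# Example usage with toy data:
# _summarize_step({(0, 0): 'S -> A', (0, 1): 'A -> B'}, 0, [1, -1, 0], [0.7, 0.1, 0.2])
# == [('A -> B', 0.7), ('S -> A', 0.2)]
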
Ejemplo n.º 14
0
class KglmDisc(Model):
    """
    Knowledge graph language model discriminator (for importance sampling).

    Parameters
    ----------
    vocab : ``Vocabulary``
        The model vocabulary.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 token_embedder: TextFieldEmbedder,
                 entity_embedder: TextFieldEmbedder,
                 relation_embedder: TextFieldEmbedder,
                 knowledge_graph_path: str,
                 use_shortlist: bool,
                 hidden_size: int,
                 num_layers: int,
                 cutoff: int = 30,
                 tie_weights: bool = False,
                 dropout: float = 0.4,
                 dropouth: float = 0.3,
                 dropouti: float = 0.65,
                 dropoute: float = 0.1,
                 wdrop: float = 0.5,
                 alpha: float = 2.0,
                 beta: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super(KglmDisc, self).__init__(vocab)

        # We extract the `Embedding` layers from the `TokenEmbedders` to apply dropout later on.
        # pylint: disable=protected-access
        self._token_embedder = token_embedder._token_embedders['tokens']
        self._entity_embedder = entity_embedder._token_embedders['entity_ids']
        self._relation_embedder = relation_embedder._token_embedders['relations']
        self._recent_entities = RecentEntities(cutoff=cutoff)
        self._knowledge_graph_lookup = KnowledgeGraphLookup(knowledge_graph_path, vocab=vocab)
        self._use_shortlist = use_shortlist
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._cutoff = cutoff
        self._tie_weights = tie_weights

        # Dropout
        self._locked_dropout = LockedDropout()
        self._dropout = dropout
        self._dropouth = dropouth
        self._dropouti = dropouti
        self._dropoute = dropoute
        self._wdrop = wdrop

        # Regularization strength
        self._alpha = alpha
        self._beta = beta

        # RNN Encoders.
        entity_embedding_dim = entity_embedder.get_output_dim()
        token_embedding_dim = token_embedder.get_output_dim()
        self.entity_embedding_dim = entity_embedding_dim
        self.token_embedding_dim = token_embedding_dim

        rnns: List[torch.nn.Module] = []
        for i in range(num_layers):
            if i == 0:
                input_size = token_embedding_dim
            else:
                input_size = hidden_size
            if i == num_layers - 1:
                output_size = token_embedding_dim + 2 * entity_embedding_dim
            else:
                output_size = hidden_size
            rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True))
        rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in rnns]
        self.rnns = torch.nn.ModuleList(rnns)

        # Various linear transformations.
        self._fc_mention_type = torch.nn.Linear(
            in_features=token_embedding_dim,
            out_features=4)

        if not use_shortlist:
            self._fc_new_entity = torch.nn.Linear(
                in_features=entity_embedding_dim,
                out_features=vocab.get_vocab_size('entity_ids'))

            if tie_weights:
                self._fc_new_entity.weight = self._entity_embedder.weight

        self._state: Optional[Dict[str, Any]] = None

        # Metrics
        self._unk_index = vocab.get_token_index(DEFAULT_OOV_TOKEN)
        self._unk_penalty = math.log(vocab.get_vocab_size('tokens_unk'))
        self._avg_mention_type_loss = Average()
        self._avg_new_entity_loss = Average()
        self._avg_knowledge_graph_entity_loss = Average()
        self._new_mention_f1 = F1Measure(positive_label=1)
        self._kg_mention_f1 = F1Measure(positive_label=2)
        self._new_entity_accuracy = CategoricalAccuracy()
        self._new_entity_accuracy20 = CategoricalAccuracy(top_k=20)
        self._parent_ppl = Ppl()
        self._relation_ppl = Ppl()

        initializer(self)

    def sample(self,
               source: Dict[str, torch.Tensor],
               target: Dict[str, torch.Tensor],
               reset: torch.Tensor,
               metadata: List[Dict[str, Any]],
               alias_copy_inds: torch.Tensor,
               shortlist: Dict[str, torch.Tensor] = None,
               **kwargs) -> Dict[str, Any]:  # **kwargs intended to eat the other fields if they are provided.
        """
        Sampling annotations for the generative model. Note that unlike forward, this function
        expects inputs from a **generative** dataset reader, not a **discriminative** one.
        """

        # Tensorize the alias_database - this will only perform the operation once.
        alias_database = metadata[0]['alias_database']
        alias_database.tensorize(vocab=self.vocab)

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)
        self._recent_entities.reset(reset)

        logp = 0.0

        mask = get_text_field_mask(target).byte()
        # We encode the target tokens (**not** source) since the discriminative model makes
        # predictions on the current token, but the generative model expects labels for the
        # **next** (e.g. target) token!
        encoded, *_ = self._encode_source(target['tokens'])
        splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
        encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)

        # Compute new mention logits
        mention_logits = self._fc_mention_type(encoded_token)
        mention_probs = F.softmax(mention_logits, dim=-1)
        mention_type = parallel_sample(mention_probs)
        mention_logp = mention_probs.gather(-1, mention_type.unsqueeze(-1)).log()
        mention_logp[~mask] = 0
        mention_logp = mention_logp.sum()

        # Compute entity logits
        new_entity_mask = mention_type.eq(1)
        new_entity_logits = self._new_entity_logits(encoded_head + encoded_relation, shortlist)
        if self._use_shortlist:
            # If using shortlist, then samples are indexed w.r.t. the shortlist and entity_ids must be looked up.
            shortlist_mask = get_text_field_mask(shortlist)
            new_entity_probs = masked_softmax(new_entity_logits, shortlist_mask)
            shortlist_inds = torch.zeros_like(mention_type)
            # Some sequences may be full of padding in which case the shortlist
            # is empty
            not_just_padding = shortlist_mask.byte().any(-1)
            shortlist_inds[not_just_padding] = parallel_sample(new_entity_probs[not_just_padding])
            shortlist_inds[~new_entity_mask] = 0
            _new_entity_logp = new_entity_probs.gather(-1, shortlist_inds.unsqueeze(-1)).log()
            new_entity_samples = shortlist['entity_ids'].gather(1, shortlist_inds)
        else:
            # If not using shortlist, then samples are indexed w.r.t. the global vocab.
            new_entity_probs = F.softmax(new_entity_logits, dim=-1)
            new_entity_samples = parallel_sample(new_entity_probs)
            _new_entity_logp = new_entity_probs.gather(-1, new_entity_samples.unsqueeze(-1)).log()
            shortlist_inds = None
        # Zero out masked tokens and non-new entity predictions
        _new_entity_logp[~mask] = 0
        _new_entity_logp[~new_entity_mask] = 0
        new_entity_logp = _new_entity_logp.sum()

        # Start filling in the entity ids
        entity_ids = torch.zeros_like(target['tokens'])
        entity_ids[new_entity_mask] = new_entity_samples[new_entity_mask]

        # ...UGH we also need the raw ids - remapping time
        raw_entity_ids = torch.zeros_like(target['tokens'])
        for *index, entity_id in nested_enumerate(entity_ids.tolist()):
            token = self.vocab.get_token_from_index(entity_id, 'entity_ids')
            raw_entity_id = self.vocab.get_token_index(token, 'raw_entity_ids')
            raw_entity_ids[tuple(index)] = raw_entity_id

        # Derived mentions need to be computed sequentially.
        parent_ids = torch.zeros_like(target['tokens']).unsqueeze(-1)
        derived_entity_mask = mention_type.eq(2)
        derived_entity_logp = 0.0

        sequence_length = target['tokens'].shape[1]

        for i in range(sequence_length):

            current_mask = derived_entity_mask[:, i] & mask[:, i]

            # -------------------  SAMPLE PARENTS ---------------------

            # Update recent entities with **current** entity only
            current_entity_id = entity_ids[:, i].unsqueeze(1)
            candidate_ids, candidate_mask = self._recent_entities(current_entity_id)

            # If no mentions are derived, there is no point continuing after entities have been updated.
            if not current_mask.any():
                continue

            # Otherwise we proceed
            candidate_embeddings = self._entity_embedder(candidate_ids)

            # Compute logits w.r.t **current** hidden state only
            current_head_encoding = encoded_head[:, i].unsqueeze(1)
            selection_logits = torch.bmm(current_head_encoding, candidate_embeddings.transpose(1, 2))
            selection_probs = masked_softmax(selection_logits, candidate_mask)

            # Only sample if there is at least one viable candidate (e.g. if a sampling distribution
            # has no probability mass we cannot sample from it). Return zero as the parent for
            # non-viable distributions.
            viable_candidate_mask = candidate_mask.any(-1).squeeze()
            _parent_ids = torch.zeros_like(current_entity_id)
            parent_logp = torch.zeros_like(current_entity_id, dtype=torch.float32)
            if viable_candidate_mask.any():
                viable_candidate_ids = candidate_ids[viable_candidate_mask]
                viable_candidate_probs = selection_probs[viable_candidate_mask]
                viable_parent_samples = parallel_sample(viable_candidate_probs)
                viable_logp = viable_candidate_probs.gather(-1, viable_parent_samples.unsqueeze(-1)).log()
                viable_parent_ids = viable_candidate_ids.gather(-1, viable_parent_samples)
                _parent_ids[viable_candidate_mask] = viable_parent_ids
                parent_logp[viable_candidate_mask] = viable_logp.squeeze(-1)

            parent_ids[current_mask, i] = _parent_ids[current_mask]  # TODO: Double-check
            derived_entity_logp += parent_logp[current_mask].sum()

            # ---------------------- SAMPLE RELATION -----------------------------

            # Lookup sampled parent ids in the knowledge graph
            indices, parent_ids_list, relations_list, tail_ids_list = self._knowledge_graph_lookup(_parent_ids)
            relation_embeddings = [self._relation_embedder(r) for r in relations_list]

            # Sample tail ids
            current_relation_encoding = encoded_relation[:, i].unsqueeze(1)
            _raw_tail_ids = torch.zeros_like(_parent_ids).squeeze(-1)
            _tail_ids = torch.zeros_like(_parent_ids).squeeze(-1)
            for index, relation_embedding, tail_id_lookup in zip(indices, relation_embeddings, tail_ids_list):
                # Compute the score for each relation w.r.t the current encoding. NOTE: In the loss
                # code index has a slice. We don't need that here since there is always a
                # **single** parent.
                logits = torch.mv(relation_embedding, current_relation_encoding[index])
                # Convert to probability
                tail_probs = F.softmax(logits, dim=-1)
                # Sample
                tail_sample = torch.multinomial(tail_probs, 1)
                # Get logp. Ignoring the current_mask here is **super** dodgy, but since we forced
                # null parents to zero we shouldn't be accumulating probabilities for unused predictions.
                tail_logp = tail_probs.gather(-1, tail_sample).log()
                derived_entity_logp += tail_logp.sum()  # Sum is redundant, just need it to make logp a scalar

                # Map back to raw id
                raw_tail_id = tail_id_lookup[tail_sample]
                # Convert raw id to id
                tail_id_string = self.vocab.get_token_from_index(raw_tail_id.item(), 'raw_entity_ids')
                tail_id = self.vocab.get_token_index(tail_id_string, 'entity_ids')

                _raw_tail_ids[index[:-1]] = raw_tail_id
                _tail_ids[index[:-1]] = tail_id

            raw_entity_ids[current_mask, i] = _raw_tail_ids[current_mask]  # TODO: Double-check
            entity_ids[current_mask, i] = _tail_ids[current_mask]  # TODO: Double-check

            self._recent_entities.insert(_tail_ids, current_mask)

            # --------------------- CONTINUE MENTIONS ---------------------------------------
            continue_mask = mention_type[:, i].eq(3) & mask[:, i]
            if not continue_mask.any() or i == 0:
                continue
            raw_entity_ids[continue_mask, i] = raw_entity_ids[continue_mask, i-1]
            entity_ids[continue_mask, i] = entity_ids[continue_mask, i-1]
            parent_ids[continue_mask, i] = parent_ids[continue_mask, i-1]
            if self._use_shortlist:
                shortlist_inds[continue_mask, i] = shortlist_inds[continue_mask, i-1]
            alias_copy_inds[continue_mask, i] = alias_copy_inds[continue_mask, i-1]

        # Lastly, because entities won't always match the true entity ids,
        # we need to zero out any alias copy ids that won't be valid.
        if 'raw_entity_ids' in kwargs:
            true_raw_entity_ids = kwargs['raw_entity_ids']['raw_entity_ids']
            invalid_id_mask = ~true_raw_entity_ids.eq(raw_entity_ids)
            alias_copy_inds[invalid_id_mask] = 0

        # Pass denotes fields that are passed directly from input to output.
        sample = {
            'source': source,  # Pass
            'target': target,  # Pass
            'reset': reset,  # Pass
            'metadata': metadata,  # Pass
            'mention_type': mention_type,
            'raw_entity_ids': {'raw_entity_ids': raw_entity_ids},
            'entity_ids': {'entity_ids': entity_ids},
            'parent_ids': {'entity_ids': parent_ids},
            'relations': {'relations': None},  # We aren't using them - eventually should remove entirely
            'shortlist': shortlist,  # Pass
            'shortlist_inds': shortlist_inds,
            'alias_copy_inds': alias_copy_inds
        }
        logp = mention_logp + new_entity_logp + derived_entity_logp
        return {'sample': sample, 'logp': logp}
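
    # ``parallel_sample`` is assumed to draw one index per distribution in the trailing
    # dimension, broadly like ``torch.multinomial`` applied row-wise. A minimal sketch of
    # that behaviour and of the subsequent log-probability lookup used throughout above:
    #
    #     import torch
    #     probs = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
    #     samples = torch.multinomial(probs, num_samples=1)   # shape: (2, 1)
    #     logp = probs.gather(-1, samples).log()              # log prob of each sample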

    @overrides
    def forward(self,  # pylint: disable=arguments-differ
                source: Dict[str, torch.Tensor],
                reset: torch.Tensor,
                metadata: List[Dict[str, Any]],
                mention_type: torch.Tensor = None,
                raw_entity_ids: Dict[str, torch.Tensor] = None,
                entity_ids: Dict[str, torch.Tensor] = None,
                parent_ids: Dict[str, torch.Tensor] = None,
                relations: Dict[str, torch.Tensor] = None,
                shortlist: Dict[str, torch.Tensor] = None,
                shortlist_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]:

        # Tensorize the alias_database - this will only perform the operation once.
        alias_database = metadata[0]['alias_database']
        alias_database.tensorize(vocab=self.vocab)

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)
        self._recent_entities.reset(reset)

        if entity_ids is not None:
            output_dict = self._forward_loop(
                source=source,
                alias_database=alias_database,
                mention_type=mention_type,
                raw_entity_ids=raw_entity_ids,
                entity_ids=entity_ids,
                parent_ids=parent_ids,
                relations=relations,
                shortlist=shortlist,
                shortlist_inds=shortlist_inds)
        else:
            # TODO: Figure out what we want here - probably to do some kind of inference on
            # entities / mention types.
            output_dict = {}

        return output_dict

    def _encode_source(self, source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        # Extract and embed source tokens.
        source_embeddings = embedded_dropout(
            embed=self._token_embedder,
            words=source,
            dropout=self._dropoute if self.training else 0)
        source_embeddings = self._locked_dropout(source_embeddings, self._dropouti)

        # Encode.
        current_input = source_embeddings
        hidden_states = []
        for layer, rnn in enumerate(self.rnns):
            # Retrieve previous hidden state for layer.
            if self._state is not None:
                prev_hidden = self._state['layer_%i' % layer]
            else:
                prev_hidden = None
            # Forward-pass.
            output, hidden = rnn(current_input, prev_hidden)
            output = output.contiguous()
            # Update hidden state for layer.
            hidden = tuple(h.detach() for h in hidden)
            hidden_states.append(hidden)
            # Apply dropout.
            if layer == self._num_layers - 1:
                dropped_output = self._locked_dropout(output, self._dropout)
            else:
                dropped_output = self._locked_dropout(output, self._dropouth)
            current_input = dropped_output
        encoded = current_input

        alpha_loss = dropped_output.pow(2).mean()
        beta_loss = (output[:, 1:] - output[:, :-1]).pow(2).mean()

        # Update state.
        self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_states)}

        return encoded, alpha_loss, beta_loss
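
    # The two extra return values implement AWD-LSTM style regularization: ``alpha_loss``
    # is activation regularization (AR), the mean squared activation of the final dropped
    # output, and ``beta_loss`` is temporal activation regularization (TAR), penalizing
    # large changes between consecutive hidden states. A minimal sketch with a toy tensor:
    #
    #     import torch
    #     output = torch.randn(2, 5, 8)                  # (batch, seq_len, hidden)
    #     alpha_loss = output.pow(2).mean()              # AR
    #     beta_loss = (output[:, 1:] - output[:, :-1]).pow(2).mean()  # TAR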

    def _mention_type_loss(self,
                           encoded: torch.Tensor,
                           mention_type: torch.Tensor,
                           mask: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss for predicting whether or not the next token will be part of an
        entity mention.
        """
        logits = self._fc_mention_type(encoded)
        mention_type_loss = sequence_cross_entropy_with_logits(logits, mention_type, mask,
                                                               average='token')
        # if not self.training:
        self._new_mention_f1(predictions=logits,
                             gold_labels=mention_type,
                             mask=mask)
        self._kg_mention_f1(predictions=logits,
                            gold_labels=mention_type,
                            mask=mask)

        return mention_type_loss

    def _new_entity_logits(self,
                           encoded: torch.Tensor,
                           shortlist: Dict[str, torch.Tensor]) -> torch.Tensor:
        if self._use_shortlist:
            # Embed the shortlist entries
            shortlist_embeddings = embedded_dropout(
                embed=self._entity_embedder,
                words=shortlist['entity_ids'],
                dropout=self._dropoute if self.training else 0)
            # Compute logits using inner product between the predicted entity embedding and the
            # embeddings of entities in the shortlist
            encodings = self._locked_dropout(encoded, self._dropout)
            logits = torch.bmm(encodings, shortlist_embeddings.transpose(1, 2))
        else:
            logits = self._fc_new_entity(encoded)
        return logits
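
    # In the shortlist branch above, logits come from a batched inner product between the
    # encodings and the shortlist entity embeddings. A minimal shape sketch (toy sizes):
    #
    #     import torch
    #     encodings = torch.randn(2, 7, 16)              # (batch, seq_len, dim)
    #     shortlist_embeddings = torch.randn(2, 5, 16)   # (batch, shortlist_size, dim)
    #     logits = torch.bmm(encodings, shortlist_embeddings.transpose(1, 2))
    #     assert logits.shape == (2, 7, 5)               # one score per shortlist entry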

    def _new_entity_loss(self,
                         encoded: torch.Tensor,
                         target_inds: torch.Tensor,
                         shortlist: Dict[str, torch.Tensor],
                         target_mask: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        target_inds : ``torch.Tensor``
            Either the shortlist inds if using shortlist, otherwise the target entity ids.
        """
        logits = self._new_entity_logits(encoded, shortlist)
        if self._use_shortlist:
            # Take masked softmax to get log probabilties and gather the targets.
            shortlist_mask = get_text_field_mask(shortlist)
            log_probs = masked_log_softmax(logits, shortlist_mask)
        else:
            log_probs = F.log_softmax(logits, dim=-1)
            num_categories = log_probs.shape[-1]
            log_probs = log_probs.view(-1, num_categories)
            target_inds = target_inds.view(-1)
        target_log_probs = torch.gather(log_probs, -1, target_inds.unsqueeze(-1)).squeeze(-1)

        mask = ~target_inds.eq(0)
        target_log_probs[~mask] = 0

        if mask.any():
            self._new_entity_accuracy(predictions=log_probs[mask],
                                      gold_labels=target_inds[mask])
            self._new_entity_accuracy20(predictions=log_probs[mask],
                                        gold_labels=target_inds[mask])

        return -target_log_probs.sum() / (target_mask.sum() + 1e-13)

    def _parent_log_probs(self,
                          encoded_head: torch.Tensor,
                          entity_ids: torch.Tensor,
                          parent_ids: torch.Tensor) -> torch.Tensor:
        # Lookup recent entities (which are candidates for parents) and get their embeddings.
        candidate_ids, candidate_mask = self._recent_entities(entity_ids)
        logger.debug('Candidate ids shape: %s', candidate_ids.shape)
        candidate_embeddings = embedded_dropout(self._entity_embedder,
                                                words=candidate_ids,
                                                dropout=self._dropoute if self.training else 0)

        # Logits are computed using a general bilinear form that measures the similarity between
        # the projected hidden state and the embeddings of candidate entities
        encoded = self._locked_dropout(encoded_head, self._dropout)
        selection_logits = torch.bmm(encoded, candidate_embeddings.transpose(1, 2))

        # Get log probabilities using masked softmax (need to double check mask works properly).

        # shape: (batch_size, sequence_length, num_candidates)
        log_probs = masked_log_softmax(selection_logits, candidate_mask)

        # Now for the tricky part. We need to convert the parent ids to a mask that selects the
        # relevant probabilities from log_probs. To do this we need to align the candidates with
        # the parent ids, which can be achieved by an element-wise equality comparison. We also
        # need to ensure that null parents are not selected.

        # shape: (batch_size, sequence_length, num_parents, 1)
        _parent_ids = parent_ids.unsqueeze(-1)

        batch_size, num_candidates = candidate_ids.shape
        # shape: (batch_size, 1, 1, num_candidates)
        _candidate_ids = candidate_ids.view(batch_size, 1, 1, num_candidates)

        # shape: (batch_size, sequence_length, num_parents, num_candidates)
        is_parent = _parent_ids.eq(_candidate_ids)
        # shape: (batch_size, 1, 1, num_candidates)
        non_null = ~_candidate_ids.eq(0)

        # Since multiplication is addition in log-space, we can apply mask by adding its log (+
        # some small constant for numerical stability).
        mask = is_parent & non_null
        masked_log_probs = log_probs.unsqueeze(2) + (mask.float() + 1e-45).log()
        logger.debug('Masked log probs shape: %s', masked_log_probs.shape)

        # Lastly, we need to get rid of the num_candidates dimension. The easy way to do this would
        # be to marginalize it out. However, since our data is sparse (the last two dims are
        # essentially a delta function) this would add a lot of unnecessary terms to the computation graph.
        # To get around this we are going to try to use a gather.
        _, index = torch.max(mask, dim=-1, keepdim=True)
        target_log_probs = torch.gather(masked_log_probs, dim=-1, index=index).squeeze(-1)

        return target_log_probs
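
    # A minimal sketch (toy tensors) of the alignment trick above: broadcasting an
    # element-wise equality between parent ids and candidate ids yields a selection mask,
    # and adding ``log(mask + 1e-45)`` zeroes out non-parents in log-space:
    #
    #     import torch
    #     parent_ids = torch.tensor([[3], [5]]).unsqueeze(-1)           # (batch, num_parents, 1)
    #     candidate_ids = torch.tensor([[3, 5], [4, 5]]).unsqueeze(1)   # (batch, 1, num_candidates)
    #     is_parent = parent_ids.eq(candidate_ids)                      # broadcasts to (2, 1, 2)
    #     log_mask = (is_parent.float() + 1e-45).log()                  # ~0 if True, very negative if False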

    def _relation_log_probs(self,
                            encoded_relation: torch.Tensor,
                            raw_entity_ids: torch.Tensor,
                            parent_ids: torch.Tensor) -> torch.Tensor:

        # Lookup edges out of parents
        indices, parent_ids_list, relations_list, tail_ids_list = self._knowledge_graph_lookup(parent_ids)

        # Embed relations
        relation_embeddings = [self._relation_embedder(r) for r in relations_list]

        # Logits are computed using a general bi-linear form that measures the similarity between
        # the projected hidden state and the embeddings of relations
        encoded = self._locked_dropout(encoded_relation, self._dropout)

        # This is a little funky, but to avoid massive amounts of padding we are going to just
        # iterate over the relation and tail_id vectors one-by-one.
        # shape: (batch_size, sequence_length, num_parents)
        target_log_probs = encoded.new_empty(*parent_ids.shape).fill_(math.log(1e-45))
        for index, parent_id, relation_embedding, tail_id in zip(indices, parent_ids_list, relation_embeddings, tail_ids_list):
            # First we compute the score for each relation w.r.t the current encoding, and convert
            # the scores to log-probabilities
            logits = torch.mv(relation_embedding, encoded[index[:-1]])
            logger.debug('Relation logits shape: %s', logits.shape)
            log_probs = F.log_softmax(logits, dim=-1)

            # Next we gather the log probs for edges with the correct tail entity and sum them up
            target_id = raw_entity_ids[index[:-1]]
            mask = tail_id.eq(target_id)
            relevant_log_probs = log_probs.masked_select(mask)
            target_log_prob = torch.logsumexp(relevant_log_probs, dim=0)
            target_log_probs[index] = target_log_prob

        return target_log_probs
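
    # For a single parent with K outgoing edges, the loop above reduces to: score each
    # relation against the hidden state, normalize, then aggregate the probability of all
    # edges whose tail is the target entity. A minimal sketch (toy sizes):
    #
    #     import torch
    #     import torch.nn.functional as F
    #     relation_embedding = torch.randn(4, 8)         # K=4 edges, embedding dim 8
    #     encoded_vec = torch.randn(8)
    #     log_probs = F.log_softmax(torch.mv(relation_embedding, encoded_vec), dim=-1)
    #     tail_id = torch.tensor([7, 2, 7, 9])
    #     target_log_prob = torch.logsumexp(log_probs[tail_id.eq(7)], dim=0)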

    def _knowledge_graph_entity_loss(self,
                                     encoded_head: torch.Tensor,
                                     encoded_relation: torch.Tensor,
                                     raw_entity_ids: torch.Tensor,
                                     entity_ids: torch.Tensor,
                                     parent_ids: torch.Tensor,
                                     target_mask: torch.Tensor) -> torch.Tensor:
        # First get the log probabilities of the parents and relations that lead to the current
        # entity.
        parent_log_probs = self._parent_log_probs(encoded_head, entity_ids, parent_ids)
        relation_log_probs = self._relation_log_probs(encoded_relation, raw_entity_ids, parent_ids)
        # Next take their product + marginalize
        combined_log_probs = parent_log_probs + relation_log_probs
        target_log_probs = torch.logsumexp(combined_log_probs, dim=-1)
        # Zero out any non-kg predictions
        mask = ~parent_ids.eq(0).all(dim=-1)
        target_log_probs = target_log_probs * mask.float()
        # If validating, measure ppl of the predictions:
        # if not self.training:
        self._parent_ppl(-torch.logsumexp(parent_log_probs, dim=-1)[mask].sum(), mask.float().sum())
        self._relation_ppl(-torch.logsumexp(relation_log_probs, dim=-1)[mask].sum(), mask.float().sum())
        # Lastly return the tokenwise average loss
        return -target_log_probs.sum() / (target_mask.sum() + 1e-13)
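
    # The marginalization above computes, in log-space,
    #     log p(entity) = log sum_p p(parent = p) * p(entity | parent = p)
    # which is exactly ``logsumexp(parent_log_probs + relation_log_probs, dim=-1)``,
    # since a product of probabilities is a sum of log-probabilities.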

    def _forward_loop(self,
                      source: Dict[str, torch.Tensor],
                      alias_database: AliasDatabase,
                      mention_type: torch.Tensor,
                      raw_entity_ids: Dict[str, torch.Tensor],
                      entity_ids: Dict[str, torch.Tensor],
                      parent_ids: Dict[str, torch.Tensor],
                      relations: Dict[str, torch.Tensor],
                      shortlist: Dict[str, torch.Tensor],
                      shortlist_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get the token mask and extract indexed text fields.
        # shape: (batch_size, sequence_length)
        target_mask = get_text_field_mask(source)
        source = source['tokens']
        raw_entity_ids = raw_entity_ids['raw_entity_ids']
        entity_ids = entity_ids['entity_ids']
        parent_ids = parent_ids['entity_ids']
        relations = relations['relations']

        logger.debug('Source & Target shape: %s', source.shape)
        logger.debug('Entity ids shape: %s', entity_ids.shape)
        logger.debug('Relations & Parent ids shape: %s', relations.shape)
        if shortlist is not None:
            logger.debug('Shortlist shape: %s', shortlist['entity_ids'].shape)
        # Embed source tokens.
        # shape: (batch_size, sequence_length, embedding_dim)
        encoded, alpha_loss, beta_loss = self._encode_source(source)
        splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
        encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)

        # Predict whether or not the next token will be an entity mention, and if so which type.
        mention_type_loss = self._mention_type_loss(encoded_token, mention_type, target_mask)
        self._avg_mention_type_loss(float(mention_type_loss))

        # For new mentions, predict which entity (among those in the supplied shortlist) will be
        # mentioned.
        if self._use_shortlist:
            new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation,
                                                    shortlist_inds,
                                                    shortlist,
                                                    target_mask)
        else:
            new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation,
                                                    entity_ids,
                                                    None,
                                                    target_mask)

        self._avg_new_entity_loss(float(new_entity_loss))

        # For derived mentions, first predict which parent(s) to expand...
        knowledge_graph_entity_loss = self._knowledge_graph_entity_loss(encoded_head,
                                                                        encoded_relation,
                                                                        raw_entity_ids,
                                                                        entity_ids,
                                                                        parent_ids,
                                                                        target_mask)
        self._avg_knowledge_graph_entity_loss(float(knowledge_graph_entity_loss))

        # Compute total loss
        loss = mention_type_loss + new_entity_loss + knowledge_graph_entity_loss

        # Activation regularization
        if self._alpha:
            loss = loss + self._alpha * alpha_loss
        # Temporal activation regularization (slowness)
        if self._beta:
            loss = loss + self._beta * beta_loss

        return {'loss': loss}

    @overrides
    def train(self, mode=True):
        # TODO: This is a temporary hack to ensure that the internal state resets when the model
        # switches from training to evaluation. The complication arises from potentially differing
        # batch sizes (e.g. the `reset` tensor will not be the right size).
        # In future implementations this should be handled more robustly.
        super().train(mode)
        self._state = None

    @overrides
    def eval(self):
        # TODO: See train.
        super().eval()
        self._state = None

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        out = {
            'type': self._avg_mention_type_loss.get_metric(reset),
            'new': self._avg_new_entity_loss.get_metric(reset),
            'kg': self._avg_knowledge_graph_entity_loss.get_metric(reset),
        }
        # if not self.training:
        p, r, f = self._new_mention_f1.get_metric(reset)
        out['new_p'] = p
        out['new_r'] = r
        out['new_f1'] = f
        p, r, f = self._kg_mention_f1.get_metric(reset)
        out['kg_p'] = p
        out['kg_r'] = r
        out['kg_f1'] = f
        out['new_ent_acc'] = self._new_entity_accuracy.get_metric(reset)
        out['new_ent_acc_20'] = self._new_entity_accuracy20.get_metric(reset)
        out['parent_ppl'] = self._parent_ppl.get_metric(reset)
        out['relation_ppl'] = self._relation_ppl.get_metric(reset)
        return out
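
# A minimal, runnable sketch (plain LSTMs, without the ``WeightDrop`` wrapper) of the
# encoder sizing used in ``KglmDisc``: intermediate layers map hidden_size -> hidden_size,
# while the final layer widens its output to hold the token, head, and relation encodings
# that are later recovered with ``split``.
import torch

def build_rnn_stack(token_dim: int, entity_dim: int, hidden_size: int, num_layers: int):
    rnns = []
    for i in range(num_layers):
        input_size = token_dim if i == 0 else hidden_size
        # Last layer emits token + 2 * entity dims so the output can be split three ways.
        output_size = token_dim + 2 * entity_dim if i == num_layers - 1 else hidden_size
        rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True))
    return torch.nn.ModuleList(rnns)

# Example: the final hidden state splits into (token, head, relation) encodings.
stack = build_rnn_stack(token_dim=16, entity_dim=8, hidden_size=32, num_layers=3)
x = torch.randn(2, 5, 16)
for rnn in stack:
    x, _ = rnn(x)
encoded_token, encoded_head, encoded_relation = x.split([16, 8, 8], dim=-1)
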
class WikiTablesSemanticParser(Model):
    u"""
    A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a question,
    and produces a logical form that answers the question when executed over the table.  The
    logical form is generated by a `type-constrained`, `transition-based` parser. This is an
    abstract class that defines most of the functionality related to the transition-based parser. It
    does not contain the implementation for actually training the parser. You may want to train it
    using a learning-to-search algorithm, in which case you will want to use
    ``WikiTablesErmSemanticParser``, or if you have a set of approximate logical forms that give the
    correct denotation, you will want to use ``WikiTablesMmlSemanticParser``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    entity_encoder : ``Seq2VecEncoder``
        The encoder to use for averaging the words of an entity.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables.
    """

    # pylint: disable=abstract-method
    def __init__(self,
                 vocab,
                 question_embedder,
                 action_embedding_dim,
                 encoder,
                 entity_encoder,
                 max_decoding_steps,
                 use_neighbor_similarity_for_linking=False,
                 dropout=0.0,
                 num_linking_features=10,
                 rule_namespace=u'rule_labels',
                 tables_directory=u'/wikitables/'):
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._max_decoding_steps = max_decoding_steps
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(tables_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._action_biases = Embedding(num_embeddings=num_actions,
                                        embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(),
                               question_embedder.get_output_dim(),
                               u"entity word average embedding dim",
                               u"question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types,
                                            self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim,
                                                self._embedding_dim)

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        if self._use_neighbor_similarity_for_linking:
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)
        else:
            self._question_entity_params = None
            self._question_neighbor_params = None

    def _get_initial_state_and_scores(self,
                                      question,
                                      table,
                                      world,
                                      actions,
                                      example_lisp_string=None,
                                      add_world_to_initial_state=False,
                                      checklist_states=None):
        u"""
        Does initial preparation and creates an initial state for both the semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table[u'text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text,
                                              num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask(
            {
                u'ignored': neighbor_indices + 1
            }, num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                              neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(
            embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings +
                                       projected_neighbor_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table[u'linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())

        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(
            actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(
            linking_scores, world, actions)
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnState(final_encoder_output[i], memory_cell[i],
                         self._first_action_embedding,
                         self._first_attended_question, encoder_output_list,
                         question_mask_list))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i])
            for i in range(batch_size)
        ]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(
            batch_indices=range(batch_size),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            output_action_embeddings=output_action_embeddings,
            action_biases=action_biases,
            action_indices=action_indices,
            possible_actions=actions,
            flattened_linking_scores=flattened_linking_scores,
            actions_to_entities=actions_to_entities,
            entity_types=entity_type_dict,
            world=initial_state_world,
            example_lisp_string=example_lisp_string,
            checklist_state=checklist_states,
            debug_info=None)
        return {
            u"initial_state": initial_state,
            u"linking_scores": linking_scores,
            u"feature_scores": feature_scores,
            u"similarity_scores": question_entity_similarity_max_score
        }

    @staticmethod
    def _get_neighbor_indices(worlds, num_entities, tensor):
        u"""
        This method returns the indices of each entity's neighbors. A tensor
        is accepted as a parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index.
        """

        num_neighbors = 0
        for world in worlds:
            for entity in world.table_graph.entities:
                if len(world.table_graph.neighbors[entity]) > num_neighbors:
                    num_neighbors = len(world.table_graph.neighbors[entity])

        batch_neighbors = []
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.table_graph.entities
            entity2index = dict(
                (entity, i) for i, entity in enumerate(entities))
            entity2neighbors = world.table_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [
                    entity2index[n] for n in entity2neighbors[entity]
                ]
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors,
                                                num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(
                neighbor_indexes, num_entities, lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        return tensor.new_tensor(batch_neighbors, dtype=torch.long)
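
    # A minimal sketch (toy tensor) of why -1 padding works here: ``abs`` turns the -1
    # padding into a valid (if arbitrary) index for ``batched_index_select``, while adding
    # 1 turns the padding into 0, which is what the mask utility treats as padding:
    #
    #     import torch
    #     neighbor_indices = torch.tensor([[0, 2, -1]])
    #     gather_indices = torch.abs(neighbor_indices)   # [[0, 2, 1]] - safe to index with
    #     mask_input = neighbor_indices + 1              # [[1, 3, 0]] - 0 marks padding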

    @staticmethod
    def _get_type_vector(worlds, num_entities, tensor):
        u"""
        Produces a one-hot encoding of each entity's type.  In addition, a
        map from flattened entity index to type id is returned, so that
        related entity-type operations can share a single method.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                one_hot_vectors = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
                                   [0, 0, 0, 1]]
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith(u'fb:cell'):
                    entity_type = 1
                elif entity.startswith(u'fb:part'):
                    entity_type = 2
                elif entity.startswith(u'fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(one_hot_vectors[entity_type])

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities,
                                            lambda: [0, 0, 0, 0])
            batch_types.append(padded)
        return tensor.new_tensor(batch_types), entity_types
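    # A small worked example of the flattened key used above (hypothetical
    # sizes): with num_entities=10, entity 3 of batch instance 2 gets key
    # 2 * 10 + 3 = 23 in ``entity_types``, which is exactly the row it
    # occupies once linking scores are flattened to
    # (batch_size * num_entities, num_question_tokens).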

    def _get_linking_probabilities(self, worlds, linking_scores, question_mask,
                                   entity_type_dict):
        u"""
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask : ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities +
                                        entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices,
                                                    dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(
                    1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores,
                                                                 dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(
                    num_question_tokens,
                    num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()
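    # The normalization above, spelled out for one question token and one
    # type with two entities scoring s1 and s2 (the null entity's score is
    # pinned to 0, so it contributes exp(0) = 1 to the denominator):
    #
    #     p(entity_i | token) = exp(s_i) / (1 + exp(s_1) + exp(s_2))
    #
    # The returned probabilities therefore need not sum to 1; the leftover
    # mass is the probability that the token links to no entity of that type.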

    @staticmethod
    def _action_history_match(predicted, targets):
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(
            torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()
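    # The min/max trick above, spelled out: ``eq`` broadcasts the prediction
    # against every target row, ``min`` over dim 1 is 1 only for rows that
    # match at every position, and ``max`` over rows is 1 if any row matched.
    # E.g. for predicted [1, 2] and targets [[1, 2, 0], [1, 3, 0]], the
    # trimmed comparison gives [[1, 1], [1, 0]], row minima [1, 0], result 1.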

    @overrides
    def get_metrics(self, reset=False):
        u"""
        We track three metrics here:

            1. dpd_acc, which is the percentage of the time that our best output action sequence is
            in the set of action sequences provided by DPD.  This is an easy-to-compute lower bound
            on denotation accuracy for the set of examples where we actually have DPD output.  We
            only score dpd_acc on that subset.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that has DPD output (make sure
            you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        return {
            u'dpd_acc': self._action_sequence_accuracy.get_metric(reset),
            u'denotation_acc': self._denotation_accuracy.get_metric(reset),
            u'lf_percent': self._has_logical_form.get_metric(reset),
        }

    @staticmethod
    def _create_grammar_state(world, possible_actions):
        valid_actions = world.get_valid_actions()
        action_mapping = {}
        for i, action in enumerate(possible_actions):
            action_string = action[0]
            action_mapping[action_string] = i
        translated_valid_actions = {}
        for key, action_strings in list(valid_actions.items()):
            translated_valid_actions[key] = [
                action_mapping[action_string]
                for action_string in action_strings
            ]
        return GrammarState([START_SYMBOL], {}, translated_valid_actions,
                            action_mapping, type_declaration.is_nonterminal)
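    # A brief illustration with a hypothetical production rule: if
    # possible_actions[3][0] is "r -> [<c,r>, c]" and that rule is valid when
    # expanding "r", then translated_valid_actions["r"] contains 3, so the
    # grammar state tracks integer action ids instead of strings.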

    def _embed_actions(self, actions):
        u"""
        Given all of the possible actions for all batch instances, produce an embedding for them.
        There will be significant overlap in this list, as the production rules from the grammar
        are shared across all batch instances.  Our returned tensor has an embedding for each
        `unique` action, so we also need to return a mapping from the original ``(batch_index,
        action_index)`` to our new ``global_action_index``, so that we can get the right action
        embedding during decoding.

        Returns
        -------
        action_embeddings : ``torch.Tensor``
            Has shape ``(num_unique_actions, action_embedding_dim)``.
        output_action_embeddings : ``torch.Tensor``
            Has shape ``(num_unique_actions, action_embedding_dim)``.
        action_biases : ``torch.Tensor``
            Has shape ``(num_unique_actions, 1)``.
        action_map : ``Dict[Tuple[int, int], int]``
            Maps ``(batch_index, action_index)`` in the input action list to ``action_index`` in
            the ``action_embeddings`` tensor.  All non-embeddable actions get mapped to `-1` here.
        """
        # TODO(mattg): This whole action pipeline might be a whole lot more complicated than it
        # needs to be.  We used to embed actions differently (using some crazy ideas about
        # embedding the LHS and RHS separately); we could probably get away with simplifying things
        # further now that we're just doing a simple embedding for global actions.  But I'm leaving
        # it like this for now to have a minimal change to go from the LHS/RHS embedding to a
        # single action embedding.
        embedded_actions = self._action_embedder.weight
        output_embedded_actions = self._output_action_embedder.weight
        action_biases = self._action_biases.weight

        # Now we just need to make a map from `(batch_index, action_index)` to
        # `global_action_index`.  global_action_ids has the list of all unique actions; here we're
        # going over all of the actions for each batch instance so we can map them to the global
        # action ids.
        action_vocab = self.vocab.get_token_to_index_vocabulary(
            self._rule_namespace)
        action_map = {}
        for batch_index, instance_actions in enumerate(actions):
            for action_index, action in enumerate(instance_actions):
                if not action[0]:
                    # This rule is padding.
                    continue
                global_action_id = action_vocab.get(action[0], -1)
                action_map[(batch_index, action_index)] = global_action_id
        return embedded_actions, output_embedded_actions, action_biases, action_map

    @staticmethod
    def _map_entity_productions(linking_scores, worlds, actions):
        u"""
        Constructs a map from ``(batch_index, action_index)`` to the flattened entity index
        ``(batch_index * num_entities + entity_index)``.
        That is, some actions correspond to terminal productions of entities from our table.  We
        need to find those actions and map them to their corresponding entity indices, where the
        entity index is its position in the list of entities returned by the ``world``.  This list
        is what defines the second dimension of the ``linking_scores`` tensor, so we can use this
        index to look up linking scores for each action in that tensor.

        For easier processing later, the mapping that we return is `flattened` - we really want to
        map ``(batch_index, action_index)`` to ``(batch_index, entity_index)``, but we are going to
        have to use the result of this mapping to do ``index_selects`` on the ``linking_scores``
        tensor.  You can't do ``index_select`` with tuples, so we flatten ``linking_scores`` to
        have shape ``(batch_size * num_entities, num_question_tokens)``, and return shifted indices
        into this flattened tensor.

        Parameters
        ----------
        linking_scores : ``torch.Tensor``
            A tensor representing linking scores between each table entity and each question token.
            Has shape ``(batch_size, num_entities, num_question_tokens)``.
        worlds : ``List[WikiTablesWorld]``
            The ``World`` for each batch instance.  The ``World`` contains a reference to the
            ``TableKnowledgeGraph`` that defines the set of entities in the linking.
        actions : ``List[List[ProductionRuleArray]]``
            The list of possible actions for each batch instance.  Our action indices are defined
            in terms of this list, so we'll find entity productions in this list and map them to
            entity indices from the entity list we get from the ``World``.

        Returns
        -------
        flattened_linking_scores : ``torch.Tensor``
            A flattened version of ``linking_scores``, with shape ``(batch_size * num_entities,
            num_question_tokens)``.
        actions_to_entities : ``Dict[Tuple[int, int], int]``
            A mapping from ``(batch_index, action_index)`` to a flattened entity index,
            representing which action indices correspond to which rows in the returned
            ``flattened_linking_scores`` tensor.
        """
        batch_size, num_entities, num_question_tokens = linking_scores.size()
        entity_map = {}
        for batch_index, world in enumerate(worlds):
            for entity_index, entity in enumerate(world.table_graph.entities):
                entity_map[(
                    batch_index,
                    entity)] = batch_index * num_entities + entity_index
        actions_to_entities = {}
        for batch_index, action_list in enumerate(actions):
            for action_index, action in enumerate(action_list):
                if not action[0]:
                    # This action is padding.
                    continue
                _, production = action[0].split(u' -> ')
                entity_index = entity_map.get((batch_index, production), None)
                if entity_index is not None:
                    actions_to_entities[(batch_index,
                                         action_index)] = entity_index
        flattened_linking_scores = linking_scores.view(
            batch_size * num_entities, num_question_tokens)
        return flattened_linking_scores, actions_to_entities
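    # A brief illustration of the flattening (hypothetical sizes): with
    # num_entities=10, an action producing entity 4 of batch instance 1 maps
    # to row 1 * 10 + 4 = 14 of ``flattened_linking_scores``, so linking
    # scores for many actions can be fetched with one ``index_select`` on
    # dimension 0.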

    @overrides
    def decode(self, output_dict):
        u"""
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``WikiTablesDecoderStep``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        action_mapping = output_dict[u'action_mapping']
        best_actions = output_dict[u"best_action_sequence"]
        debug_infos = output_dict[u'debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                izip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in izip(
                    predicted_actions, debug_info):
                action_info = {}
                action_info[u'predicted_action'] = predicted_action
                considered_actions = action_debug_info[u'considered_actions']
                probabilities = action_debug_info[u'probabilities']
                actions = []
                for action, probability in izip(considered_actions,
                                                probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)],
                                        probability))
                actions.sort()
                considered_actions, probabilities = izip(*actions)
                action_info[u'considered_actions'] = considered_actions
                action_info[u'action_probabilities'] = probabilities
                action_info[u'question_attention'] = action_debug_info.get(
                    u'question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict[u"predicted_actions"] = batch_action_info
        return output_dict
Ejemplo n.º 16
0
class WikiTablesErmSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesErmSemanticParser`` is a :class:`WikiTablesSemanticParser` that learns to search
    for logical forms that yield the correct denotations.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used to average the words of an entity. Passed to super class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    decoder_beam_size : ``int``
        Beam size to be used by the ExpectedRiskMinimization algorithm.
    decoder_num_finished_states : ``int``
        Number of finished states for which costs will be computed by the ExpectedRiskMinimization
        algorithm.
    max_decoding_steps : ``int``
        Maximum number of steps the decoder should take before giving up. Used both during training
        and evaluation. Passed to super class.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.  Passed to super class.
    normalize_beam_score_by_length : ``bool``, optional (default=False)
        Should we normalize the log-probabilities by length before renormalizing the beam? This was
        shown to work better for NMT by Edunov et al., but that may not be the case for semantic
        parsing.
    checklist_cost_weight : ``float``, optional (default=0.6)
        Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases, we
        weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
        denotation accuracy.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to super
        class.
    mml_model_file : ``str``, optional (default=None)
        If you want to initialize this model using weights from another model trained using MML,
        pass the path to the ``model.tar.gz`` file of that model here.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 attention: Attention,
                 decoder_beam_size: int,
                 decoder_num_finished_states: int,
                 max_decoding_steps: int,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 normalize_beam_score_by_length: bool = False,
                 checklist_cost_weight: float = 0.6,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 mml_model_file: str = None) -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(vocab=vocab,
                         question_embedder=question_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         entity_encoder=entity_encoder,
                         max_decoding_steps=max_decoding_steps,
                         add_action_bias=add_action_bias,
                         use_neighbor_similarity_for_linking=use_similarity,
                         dropout=dropout,
                         num_linking_features=num_linking_features,
                         rule_namespace=rule_namespace)
        # Not sure why mypy needs a type annotation for this!
        self._decoder_trainer: ExpectedRiskMinimization = \
                ExpectedRiskMinimization(beam_size=decoder_beam_size,
                                         normalize_by_length=normalize_beam_score_by_length,
                                         max_decoding_steps=self._max_decoding_steps,
                                         max_num_finished_states=decoder_num_finished_states)
        self._decoder_step = LinkingCoverageTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            add_action_bias=self._add_action_bias,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout)
        self._checklist_cost_weight = checklist_cost_weight
        self._agenda_coverage = Average()
        # We don't need a separate beam search since the trainer does that already. But we're defining one just to
        # be able to use interactive beam search (a functionality that's only implemented in the ``BeamSearch``
        # class) in the demo. We'll use this only at test time.
        self._beam_search: BeamSearch = BeamSearch(beam_size=decoder_beam_size)
        # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
        # copied a trained ERM model from a different machine and the original MML model that was
        # used to initialize it does not exist on the current machine. This may not be the best
        # solution for the problem.
        if mml_model_file is not None:
            if os.path.isfile(mml_model_file):
                archive = load_archive(mml_model_file)
                self._initialize_weights_from_archive(archive)
            else:
                # A model file is passed, but it does not exist. This is expected to happen when
                # you're using a trained ERM model to decode. But it may also happen if the path to
                # the file is really just incorrect, so we log a warning instead of raising.
                logger.warning(
                    "MML model file for initializing weights is passed, but does not exist."
                    " This is fine if you're just decoding.")

    def _initialize_weights_from_archive(self, archive: Archive) -> None:
        logger.info("Initializing weights from MML model.")
        model_parameters = dict(self.named_parameters())
        archived_parameters = dict(archive.model.named_parameters())
        question_embedder_weight = "_question_embedder.token_embedder_tokens.weight"
        if question_embedder_weight not in archived_parameters or \
           question_embedder_weight not in model_parameters:
            raise RuntimeError(
                "When initializing model weights from an MML model, we need "
                "the question embedder to be a TokenEmbedder using namespace called "
                "tokens.")
        for name, weights in archived_parameters.items():
            if name in model_parameters:
                if name == question_embedder_weight:
                    # The shapes of embedding weights will most likely differ between the two models
                    # because the vocabularies will most likely be different. We will get a mapping
                    # of indices from this model's token indices to the archived model's and copy
                    # the tensor accordingly.
                    vocab_index_mapping = self._get_vocab_index_mapping(
                        archive.model.vocab)
                    archived_embedding_weights = weights.data
                    new_weights = model_parameters[name].data.clone()
                    for index, archived_index in vocab_index_mapping:
                        new_weights[index] = archived_embedding_weights[
                            archived_index]
                    logger.info("Copied embeddings of %d out of %d tokens",
                                len(vocab_index_mapping),
                                new_weights.size()[0])
                else:
                    new_weights = weights.data
                logger.info("Copying parameter %s", name)
                model_parameters[name].data.copy_(new_weights)

    def _get_vocab_index_mapping(
            self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]:
        vocab_index_mapping: List[Tuple[int, int]] = []
        for index in range(self.vocab.get_vocab_size(namespace='tokens')):
            token = self.vocab.get_token_from_index(index=index,
                                                    namespace='tokens')
            archived_token_index = archived_vocab.get_token_index(
                token, namespace='tokens')
            # Checking if we got the UNK token index, because we don't want all new token
            # representations initialized to UNK token's representation. We do that by checking if
            # the two tokens are the same. They will not be if the token at the archived index is
            # UNK.
            if archived_vocab.get_token_from_index(
                    archived_token_index, namespace="tokens") == token:
                vocab_index_mapping.append((index, archived_token_index))
        return vocab_index_mapping
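    # A sketch of the UNK check above with a hypothetical token: if the
    # archived vocabulary lacks "zebra", ``get_token_index`` returns the UNK
    # index, ``get_token_from_index`` on it yields "@@UNKNOWN@@" != "zebra",
    # and the pair is (correctly) left out of the mapping.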

    @overrides
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesLanguage],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            target_values: List[List[str]] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesLanguage]``
            We use a ``MetadataField`` to get the ``WikiTablesLanguage`` object for each input instance.
            Because of how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesLanguage]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``world`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        agenda : ``torch.LongTensor``
            Agenda vectors that the checklist vectors will be compared against to compute the checklist
            cost.
        target_values : ``List[List[str]]``, optional (default = None)
            For each instance, a list of target values taken from the example lisp string. We pass
            this list to the evaluator along with logical forms to compute denotation accuracy.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' field.
        """
        batch_size = list(question.values())[0].size(0)
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        checklist_states = []
        all_terminal_productions = [
            set(instance_world.terminal_productions.values())
            for instance_world in world
        ]
        max_num_terminals = max(
            [len(terminals) for terminals in all_terminal_productions])
        for instance_actions, instance_agenda, terminal_productions in zip(
                actions, agenda_list, all_terminal_productions):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions,
                                                      terminal_productions,
                                                      max_num_terminals)
            checklist_target, terminal_actions, checklist_mask = checklist_info
            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            checklist_states.append(
                ChecklistStatelet(terminal_actions=terminal_actions,
                                  checklist_target=checklist_target,
                                  checklist_mask=checklist_mask,
                                  checklist=initial_checklist))
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(
            question, table, world, actions, outputs)

        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),  # type: ignore
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            checklist_state=checklist_states,
            possible_actions=actions,
            extras=target_values,
            debug_info=None)

        if target_values is not None:
            logger.debug("Target values: %s", target_values)
            trainer_outputs = self._decoder_trainer.decode(
                initial_state,  # type: ignore
                self._decoder_step,
                partial(self._get_state_cost, world))
            outputs.update(trainer_outputs)
        else:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            agenda_indices = [agenda_[:, 0].cpu().data for agenda_ in agenda]
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            best_final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            for i in range(batch_size):
                in_agenda_ratio = 0.0
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    action_sequence = best_final_states[i][0].action_history[0]
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in action_sequence
                    ]
                    instance_possible_actions = actions[i]
                    agenda_actions = []
                    for rule_id in agenda_indices[i]:
                        rule_id = int(rule_id)
                        if rule_id == -1:
                            continue
                        action_string = instance_possible_actions[rule_id][0]
                        agenda_actions.append(action_string)
                    actions_in_agenda = [
                        action in action_strings for action in agenda_actions
                    ]
                    if actions_in_agenda:
                        # Note: This means that when there are no actions on agenda, agenda coverage
                        # will be 0, not 1.
                        in_agenda_ratio = sum(actions_in_agenda) / len(
                            actions_in_agenda)
                self._agenda_coverage(in_agenda_ratio)

            self._compute_validation_outputs(actions, best_final_states, world,
                                             target_values, metadata, outputs)
        return outputs

    @staticmethod
    def _get_checklist_info(
        agenda: torch.LongTensor, all_actions: List[ProductionRule],
        terminal_productions: Set[str], max_num_terminals: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
        world, and a length to pad the checklist vectors to, and returns a target checklist against
        which the checklist at each state will be compared to compute a loss, indices of
        ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions
        are relevant for checklist loss computation.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRule]``
            All actions for one instance.
        ``terminal_productions`` : ``Set[str]``
            String representations of terminal productions in the corresponding world.
        ``max_num_terminals`` : ``int``
            Length to which the checklist vectors will be padded. This is the maximum number of
            terminal productions among all the worlds in the batch.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = {
            int(x)
            for x in agenda.squeeze(0).detach().cpu().numpy()
        }
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRule, a tuple where the first item is the production
            # rule string.
            if action[0] in terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        while len(target_checklist_list) < max_num_terminals:
            target_checklist_list.append([0])
            terminal_indices.append([-1])
        # (max_num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (max_num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list,
                                             dtype=torch.float)
        checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
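    # A small worked example with hypothetical actions: if actions 2 and 5
    # are terminal productions, only action 5 is on the agenda, and
    # max_num_terminals is 3, then
    #
    #     terminal_actions == [[2], [5], [-1]]
    #     target_checklist == [[0.], [1.], [0.]]
    #     checklist_mask   == [[0.], [1.], [0.]]
    #
    # so the checklist loss only attends to agenda items.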

    def _get_state_cost(self, worlds: List[WikiTablesLanguage],
                        state: CoverageState) -> torch.Tensor:
        if not state.is_finished():
            raise RuntimeError(
                "_get_state_cost() is not defined for unfinished states!")
        world = worlds[state.batch_indices[0]]

        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
        # penalizing agenda actions produced multiple times.
        checklist_balance = torch.clamp(state.checklist_state[0].get_balance(),
                                        min=0.0)
        checklist_cost = torch.sum((checklist_balance)**2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        denotation_cost = torch.sum(
            state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        action_history = state.action_history[0]
        batch_index = state.batch_indices[0]
        action_strings = [
            state.possible_actions[batch_index][i][0] for i in action_history
        ]
        target_values = state.extras[batch_index]
        # Silence the executor's logger while evaluating candidate sequences: failed
        # executions are expected here and would otherwise flood the logs.
        executor_logger = \
                logging.getLogger('allennlp.semparse.domain_languages.wikitables_language')
        executor_logger.setLevel(logging.ERROR)
        evaluation = world.evaluate_action_sequence(action_strings,
                                                    target_values)
        if evaluation:
            cost = checklist_cost
        else:
            cost = checklist_cost + (
                1 - self._checklist_cost_weight) * denotation_cost
        return cost
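    # In short, with w = self._checklist_cost_weight, the cost above is
    #
    #     cost = w * checklist_cost                        if the denotation is correct
    #     cost = w * checklist_cost + (1 - w) * |agenda|   otherwise
    #
    # where |agenda| is the number of target checklist items, so incorrect
    # denotations always pay an extra penalty proportional to agenda size.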

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        The base class returns a dict with dpd accuracy, denotation accuracy, and logical form
        percentage metrics. We add the agenda coverage metric here.
        """
        metrics = super().get_metrics(reset)
        metrics["agenda_coverage"] = self._agenda_coverage.get_metric(reset)
        return metrics
Ejemplo n.º 17
0
class Seq2SeqClaimRank(Model):
    """
    A ``Seq2SeqClaimRank`` model. This model is intended to be trained with a multi-instance
    learning objective that simultaneously tries to:
        - Decode the given post modifier (e.g. the ``target`` sequence).
        - Ensure that the model is attending to the proper claims during decoding (which are
        identified by the ``labels`` variable).
    The basic architecture is a seq2seq model with attention where the input sequence is the source
    sentence (without post-modifier), and the output sequence is the post-modifier. The main
    difference is that instead of performing attention over the input sequence, attention is
    performed over a collection of claims.

    Parameters
    ----------
    text_field_embedder : ``TextFieldEmbedder``
        Embeds words in the source sentence / claims.
    sentence_encoder : ``Seq2VecEncoder``
        Encodes the entire source sentence into a single vector.
    claim_encoder : ``Seq2SeqEncoder``
        Encodes each claim into a single vector.
    attention : ``Attention``
        Type of attention mechanism used.
        WARNING: Do not normalize attention scores, and make sure to use a
        sigmoid activation. Otherwise the claim ranking loss will not work
        properly!
    max_steps : ``int``
        Maximum number of decoding steps. Default: 100 (same as ONMT).
    beam_size: ``int``
        Beam size used during evaluation. Default: 5 (same as ONMT).
    beta: ``float``
        Weight of attention loss term.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sentence_encoder: Seq2VecEncoder,
                 claim_encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_steps: int = 100,
                 beam_size: int = 5,
                 beta: float = 1.0) -> None:
        super(Seq2SeqClaimRank, self).__init__(vocab)

        self.text_field_embedder = text_field_embedder
        self.sentence_encoder = sentence_encoder
        self.claim_encoder = TimeDistributed(claim_encoder)  # Handles additional sequence dim
        self.claim_encoder_dim = claim_encoder.get_output_dim()
        self.attention = attention
        self.decoder_embedding_dim = text_field_embedder.get_output_dim()
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.beta = beta

        # self.target_embedder = torch.nn.Embedding(vocab.get_vocab_size(), decoder_embedding_dim)

        # Since we are using the sentence encoding as the initial hidden state to the decoder, the
        # decoder hidden dim must match the sentence encoder hidden dim.
        self.decoder_output_dim = sentence_encoder.get_output_dim()
        self.decoder_0_cell = torch.nn.LSTMCell(self.decoder_embedding_dim + self.claim_encoder_dim,
                                                self.decoder_output_dim)
        self.decoder_1_cell = torch.nn.LSTMCell(self.decoder_output_dim,
                                                self.decoder_output_dim)

        # When projecting out we will use attention to combine claim embeddings into a single
        # context embedding, this will be concatenated with the decoder cell output before being
        # fed to the projection layer. Hence the expected input size is:
        #   decoder output dim + claim encoder output dim
        projection_input_dim = self.decoder_output_dim + self.claim_encoder_dim
        self.output_projection_layer = torch.nn.Linear(projection_input_dim,
                                                       vocab.get_vocab_size())

        self._start_index = self.vocab.get_token_index('<s>')
        self._end_index = self.vocab.get_token_index('</s>')

        self.beam_search = BeamSearch(self._end_index, max_steps=max_steps, beam_size=beam_size)
        pad_index = vocab.get_token_index(vocab._padding_token)
        self.bleu = BLEU(exclude_indices={pad_index, self._start_index, self._end_index})
        self.avg_reconstruction_loss = Average()
        self.avg_claim_scoring_loss = Average()

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        output_projections, _, state = self._prepare_output_projections(last_predictions, state)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)
        return class_log_probabilities, state

    @overrides
    def forward(self,
                inputs: Dict[str, torch.LongTensor],
                claims: Dict[str, torch.LongTensor],
                targets: Dict[str, torch.LongTensor] = None,
                labels: torch.Tensor = None) -> torch.Tensor:
        """Forward pass of the model + decoder logic.

        Parameters
        ----------
        inputs : ``Dict[str, torch.LongTensor]``
            Output of `TextField.as_array()` from the `input` field.
        claims : ``Dict[str, torch.LongTensor]``
            Output of `ListField.as_array()` from the `claims` field.
        targets : ``Dict[str, torch.LongTensor]``
            Output of `TextField.as_array()` from the `target` field.
            Only expected during training and validation.
        labels : ``torch.Tensor``
            Output of `LabelField.as_array()` from the `labels` field, indicating which claims were
            used.
            Only expected during training and validation.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary containing loss tensor and decoder outputs.
        """

        # Obtain an encoding for each input sentence (i.e. the contexts).
        input_mask = util.get_text_field_mask(inputs)
        input_word_embeddings = self.text_field_embedder(inputs)
        input_encodings = self.sentence_encoder(input_word_embeddings, input_mask)

        # Next we encode claims. Note that here we have two additional sequence dimensions (since
        # there are multiple claims per instance, and we want to apply attention at the word
        # level). To deal with this we need to set `num_wrapping_dims=1` for the embedder, and make
        # the claim encoder TimeDistributed.
        claim_mask = util.get_text_field_mask(claims, num_wrapping_dims=1)
        claim_word_embeddings = self.text_field_embedder(claims, num_wrapping_dims=1)
        claim_encodings = self.claim_encoder(claim_word_embeddings, claim_mask)

        # Package the encoder outputs into a state dictionary.
        state = {
            'input_mask': input_mask,
            'input_encodings': input_encodings,
            'claim_mask': claim_mask,
            'claim_encodings': claim_encodings
        }

        # If ``targets`` (the post-modifier) and ``labels`` (indicators of which claims are
        # used) are provided, then we use them to compute the loss.
        if (targets is not None) and (labels is not None):
            state = self._init_decoder_state(state)
            output_dict = self._forward_loop(state, targets, labels)
        else:
            output_dict = {}

        # If model is not training, then we perform beam search for decoding to obtain higher
        # quality outputs.
        if not self.training:
            # Perform beam search
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            # Compute BLEU
            top_k_predictions = output_dict['predictions']
            best_predictions = top_k_predictions[:, 0, :]
            self.bleu(best_predictions, targets['tokens'])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.
        """
        predicted_indices = output_dict['predictions']
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x) for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _init_decoder_state(self,
                            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Adds fields to the state required to initialize the decoder."""

        batch_size = state['input_mask'].shape[0]

        # The first decoder layer starts from zeros (approximating the layer
        # structure in OpenNMT's diagram).
        state['decoder_0_h'] = state['input_encodings'].new_zeros(batch_size, self.decoder_output_dim)
        state['decoder_0_c'] = state['input_encodings'].new_zeros(batch_size, self.decoder_output_dim)

        # Initialize the LSTM hidden state (i.e. h_0) with the output of the sentence encoder.
        state['decoder_1_h'] = state['input_encodings']
        # Initialize the LSTM context state (i.e. c_0) with zeros.
        state['decoder_1_c'] = state['input_encodings'].new_zeros(batch_size, self.decoder_output_dim)
        # Initialize previous context.
        state['prev_context'] = state['input_encodings'].new_zeros(batch_size, self.claim_encoder_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      targets: Dict[str, torch.Tensor],
                      labels: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute loss using greedy decoding."""
        batch_size = state['input_mask'].shape[0]
        target_tokens = targets['tokens']
        num_decoding_steps = target_tokens.shape[1] - 1

        # Greedy decoding phase
        output_logit_list = []
        attention_logit_list = []
        select_idx_list = []
        for timestep in range(num_decoding_steps):
            # Feed target sequence as input
            decoder_input = target_tokens[:, timestep]
            output_logits, attention_logits, state = self._prepare_output_projections(decoder_input, state)
            # Store output and attention logits
            output_logit_list.append(output_logits.unsqueeze(1))
            attention_logit_list.append(attention_logits.unsqueeze(1))

        # Compute reconstruction loss
        output_logit_tensor = torch.cat(output_logit_list, dim=1)
        relevant_target_tokens = target_tokens[:, 1:].contiguous()
        target_mask = util.get_text_field_mask(targets)[:, 1:].contiguous()
        reconstruction_loss = util.sequence_cross_entropy_with_logits(output_logit_tensor,
                                                                      relevant_target_tokens,
                                                                      target_mask)

        # Compute claim scoring loss. A loss is computed between **each** attention vector and the
        # true label. In order for that to work we need to:
        #   a. Tile the source labels (so that they are copied for each word)
        #   b. Mask out padding tokens - this requires taking the outer-product of the target mask
        #       and the claim mask
        attention_logit_tensor = torch.cat(attention_logit_list, dim=1)
        claim_level_mask = (state['claim_mask'].sum(-1) > 0).long()
        attention_mask = target_mask.unsqueeze(-1) * claim_level_mask.unsqueeze(1)
        labels = labels.unsqueeze(1).repeat(1, num_decoding_steps, 1).float()
        claim_scoring_loss = F.binary_cross_entropy_with_logits(attention_logit_tensor, labels, reduction='none')
        claim_scoring_loss *= attention_mask.float()  # Apply mask

        # We want to apply 'batch' reduction (as is done in
        # `sequence_cross_entropy_with_logits`), which entails averaging over one
        # dimension at a time, counting only non-empty entries in each denominator.
        denom = attention_mask
        for _ in range(3):
            denom = denom.sum(-1)
            claim_scoring_loss = claim_scoring_loss.sum(-1) / (denom.float() + 1e-13)
            denom = (denom > 0)
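        # Dimension walk-through of the loop above: the first pass averages over
        # claims (dividing by the per-timestep masked claim count), the second
        # over timesteps that have at least one claim, and the third over batch
        # instances, leaving a scalar loss.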

        total_loss = reconstruction_loss + self.beta * claim_scoring_loss

        # Update metrics
        self.avg_reconstruction_loss(reconstruction_loss)
        self.avg_claim_scoring_loss(claim_scoring_loss)

        output_dict =  {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "claim_scoring_loss": claim_scoring_loss,
            "attention_logits": attention_logit_tensor
        }

        return output_dict

    def _prepare_output_projections(self,
                                    decoder_input: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        # Embed decoder input
        decoder_word_embeddings = self.text_field_embedder({'tokens': decoder_input})

        # Concat with previous context
        concat = torch.cat((decoder_word_embeddings, state['prev_context']), dim=-1)

        # Run forward pass of decoder RNN
        decoder_0_h, decoder_0_c = self.decoder_0_cell(concat, (state['decoder_0_h'], state['decoder_0_c']))
        decoder_1_h, decoder_1_c = self.decoder_1_cell(decoder_0_h, (state['decoder_1_h'], state['decoder_1_c']))
        state['decoder_0_h'] = decoder_0_h
        state['decoder_0_c'] = decoder_0_c
        state['decoder_1_h'] = decoder_1_h
        state['decoder_1_c'] = decoder_1_c

        # Compute attention and get context embedding. We get an attention score for each word in
        # each claim. Then we sum up scores to get a claim level score (so we can use overlap as
        # supervision).
        claim_encodings = state['claim_encodings']
        claim_mask = state['claim_mask']
        batch_size, n_claims, claim_length, dim = claim_encodings.shape

        flattened_claim_encodings = claim_encodings.view(batch_size, -1, dim)
        flattened_claim_mask = claim_mask.view(batch_size, -1)
        flattened_attention_logits = self.attention(decoder_1_h, flattened_claim_encodings, flattened_claim_mask)
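        # shape: (batch_size, num_claims * claim_length)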
        attention_logits = flattened_attention_logits.view(batch_size, n_claims, claim_length)

        # Now get claim level encodings by summing word level attention.
        word_level_attention = util.masked_softmax(attention_logits, claim_mask)
        claim_encodings = util.weighted_sum(claim_encodings, word_level_attention)

        # If not training, get max attention word to replace unk
        if not self.training:
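            # Find the highest-attended word within each claim, then the claim whose best word
            # scores highest overall; the resulting (claim, word) index pair is what gets copied
            # in place of UNK at prediction time. Note the bare `squeeze()` calls assume
            # batch_size > 1.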
            max_word = word_level_attention.argmax(dim=-1, keepdim=True)
            gathered = word_level_attention.gather(dim=-1, index=max_word)
            max_claim = gathered.squeeze().argmax(dim=-1, keepdim=True)
            max_word = max_word.squeeze().gather(dim=1, index=max_claim)
            select_idx = torch.cat((max_claim, max_word), dim=-1)
        else:
            select_idx = None

        # We compute the context vector from the claim-level encodings, weighting each claim
        # by a sigmoid of its summed word-level attention logits (independent per-claim
        # weights, renormalized so they sum to one).
        claim_mask = (claim_mask.sum(-1) > 0).float()
        attention_logits = attention_logits.sum(-1)
        attention_weights = torch.sigmoid(attention_logits) * claim_mask
        normalized_attention_weights = attention_weights / (attention_weights.sum(-1, True) + 1e-13)
        context_embedding = util.weighted_sum(claim_encodings, normalized_attention_weights)
        state['prev_context'] = context_embedding

        # Concatenate RNN output w/ context vector and feed through final hidden layer
        projection_input = torch.cat((decoder_1_h, context_embedding), dim=-1)
        output_logits = self.output_projection_layer(projection_input)

        return output_logits, attention_logits, state

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state['input_mask'].size()[0]
        start_predictions = state['input_mask'].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self.beam_search.search(
                start_predictions, state, self.take_step)

        output_dict = {
                "class_log_probabilities": log_probabilities,
                "predictions": all_top_k_predictions,
        }
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {
            'recon': self.avg_reconstruction_loss.get_metric(reset=reset).data.item(),
            'claim': self.avg_claim_scoring_loss.get_metric(reset=reset).data.item()
        }
        # Only update BLEU score during validation and evaluation
        if not self.training:
            all_metrics.update(self.bleu.get_metric(reset=reset))
        return all_metrics
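Below is a minimal, self-contained sketch (made-up shapes and tensor names, mirroring but not taken from the listing above) of the three-level masked mean used for the claim scoring loss: per-claim labels are tiled over decoding steps, masked by the outer product of the step and claim masks, then averaged over claims, steps, and finally the batch, counting only unmasked entries at each level.

import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
batch_size, num_steps, num_claims = 2, 4, 3
logits = torch.randn(batch_size, num_steps, num_claims)         # per-step claim logits
labels = torch.randint(0, 2, (batch_size, num_claims)).float()  # per-claim labels
step_mask = torch.tensor([[1, 1, 1, 0],
                          [1, 1, 0, 0]])                        # (batch, steps)
claim_mask = torch.tensor([[1, 1, 0],
                           [1, 1, 1]])                          # (batch, claims)

# Tile labels over decoding steps; the mask is the outer product of the two masks.
tiled_labels = labels.unsqueeze(1).repeat(1, num_steps, 1)      # (batch, steps, claims)
mask = step_mask.unsqueeze(-1) * claim_mask.unsqueeze(1)        # (batch, steps, claims)
loss = F.binary_cross_entropy_with_logits(logits, tiled_labels, reduction='none')
loss = loss * mask.float()

# Masked mean over claims, then steps, then the batch.
denom = mask
for _ in range(3):
    denom = denom.sum(-1)
    loss = loss.sum(-1) / (denom.float() + 1e-13)
    denom = denom > 0

print(loss.item())  # a single scalar, averaged only over unmasked entries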
Ejemplo n.º 18
0
class WikiTablesSemanticParser(Model):
    """
    A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a question,
    and produces a logical form that answers the question when executed over the table.  The
    logical form is generated by a `type-constrained`, `transition-based` parser. This is an
    abstract class that defines most of the functionality related to the transition-based parser. It
    does not contain the implementation for actually training the parser. You may want to train it
    using a learning-to-search algorithm, in which case you will want to use
    ``WikiTablesErmSemanticParser``, or if you have a set of approximate logical forms that give the
    correct denotation, you will want to use ``WikiTablesMmlSemanticParser``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        question_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        entity_encoder: Seq2VecEncoder,
        max_decoding_steps: int,
        add_action_bias: bool = True,
        use_neighbor_similarity_for_linking: bool = False,
        dropout: float = 0.0,
        num_linking_features: int = 10,
        rule_namespace: str = "rule_labels",
    ) -> None:
        super().__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions,
                                            embedding_dim=1)
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        check_dimensions_match(
            entity_encoder.get_output_dim(),
            question_embedder.get_output_dim(),
            "entity word average embedding dim",
            "question embedding dim",
        )

        self._num_entity_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._entity_type_encoder_embedding = Embedding(
            num_embeddings=self._num_entity_types,
            embedding_dim=self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            num_embeddings=self._num_entity_types,
            embedding_dim=action_embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim,
                                                self._embedding_dim)

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        if self._use_neighbor_similarity_for_linking:
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)
        else:
            self._question_entity_params = None
            self._question_neighbor_params = None

    def _get_initial_rnn_and_grammar_state(
        self,
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[WikiTablesLanguage],
        actions: List[List[ProductionRuleArray]],
        outputs: Dict[str, Any],
    ) -> Tuple[List[RnnStatelet], List[GrammarStatelet]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnStatelet and GrammarStatelet for each batch instance to pass to the
        decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table["text"]
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question)
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        if neighbor_indices is not None:
            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(
                encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask(
                {
                    "ignored": {
                        "ignored": neighbor_indices + 1
                    }
                },
                num_wrapping_dims=1).float()

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(
                BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                                  neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(
                embedded_neighbors.float())

            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings +
                                           projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2),
        )

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table["linking"]

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = (projected_question_entity_similarity +
                              projected_question_neighbor_similarity)

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
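        # (batch_size, num_question_tokens, embedding_dim * 2)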
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(
                    final_encoder_output[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    self._first_attended_question,
                    encoder_output_list,
                    question_mask_list,
                ))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs["linking_scores"] = linking_scores
            if feature_scores is not None:
                outputs["feature_scores"] = feature_scores
            outputs["similarity_scores"] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state

    @staticmethod
    def _get_neighbor_indices(worlds: List[WikiTablesLanguage],
                              num_entities: int,
                              tensor: torch.Tensor) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors. A tensor
        is accepted as a parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[WikiTablesLanguage]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index. If all the entities in the batch
        have no neighbors, None will be returned.
        """

        num_neighbors = 0
        for world in worlds:
            for entity in world.table_graph.entities:
                if len(world.table_graph.neighbors[entity]) > num_neighbors:
                    num_neighbors = len(world.table_graph.neighbors[entity])

        batch_neighbors = []
        no_entities_have_neighbors = True
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.table_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.table_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [
                    entity2index[n] for n in entity2neighbors[entity]
                ]
                if entity_neighbors:
                    no_entities_have_neighbors = False
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors,
                                                num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(
                neighbor_indexes, num_entities, lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        # It is possible that none of the entities has any neighbors, since our definition of the
        # knowledge graph allows it when no entities or numbers were extracted from the question.
        if no_entities_have_neighbors:
            return None
        return tensor.new_tensor(batch_neighbors, dtype=torch.long)

    @staticmethod
    def _get_type_vector(
            worlds: List[WikiTablesLanguage], num_entities: int,
            tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's
        type. In addition, a map from a flattened entity index to type is returned to combine
        entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[WikiTablesLanguage]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                # We need numbers to be first, then date columns, then number columns, strings, and
                # string columns, in that order, because our entities are going to be sorted.  We do
                # a split by type and then a merge later, and it relies on this sorting.
                if entity.startswith("date_column:"):
                    entity_type = 1
                elif entity.startswith("number_column:"):
                    entity_type = 2
                elif entity.startswith("string:"):
                    entity_type = 3
                elif entity.startswith("string_column:"):
                    entity_type = 4
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)
        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types

    def _get_linking_probabilities(
        self,
        worlds: List[WikiTablesLanguage],
        linking_scores: torch.FloatTensor,
        question_mask: torch.LongTensor,
        entity_type_dict: Dict[int, int],
    ) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[WikiTablesLanguage]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "date_column:", followed by "number_column:", "string:", and "string_column:".
            # This is not a great assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities +
                                        entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices,
                                                    dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(
                    1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores,
                                                                 dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(
                    num_question_tokens,
                    num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
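        # `min` over dim 1 is 1 only when every position matches the trimmed target exactly;
        # `max` then checks whether any target in the list matches.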
        return torch.max(
            torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. lf_retrieval_acc, which is the percentage of the time that our best output action
            sequence is in the set of action sequences provided by offline search.  This is an
            easy-to-compute lower bound on denotation accuracy for the set of examples where we
            actually have offline output.  We only score lf_retrieval_acc on that subset.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that has DPD output (make sure
            you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        return {
            "lf_retrieval_acc":
            self._action_sequence_accuracy.get_metric(reset),
            "denotation_acc": self._denotation_accuracy.get_metric(reset),
            "lf_percent": self._has_logical_form.get_metric(reset),
        }

    def _create_grammar_state(
        self,
        world: WikiTablesLanguage,
        possible_actions: List[ProductionRuleArray],
        linking_scores: torch.Tensor,
        entity_types: torch.Tensor,
    ) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of
        creating that is creating the `valid_actions` dictionary, which contains embedded
        representations of all of the valid actions.  So, we create that here as well.

        The way we represent the valid expansions is a little complicated: we use a
        dictionary of `action types`, where the key is the action type (like "global", "linked", or
        whatever your model is expecting), and the value is a tuple representing all actions of
        that type.  The tuple is (input tensor, output tensor, action id).  The input tensor has
        the representation that is used when `selecting` actions, for all actions of this type.
        The output tensor has the representation that is used when feeding the action to the next
        step of the decoder (this could just be the same as the input tensor).  The action ids are
        a list of indices into the main action list for each batch instance.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRuleArrays``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``WikiTablesLanguage``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRuleArray]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        # TODO(mattg): Move the "valid_actions" construction to another method.
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_nonterminal_productions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions if any.
            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0)
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                if self._add_action_bias:
                    global_action_biases = self._action_biases(
                        global_action_tensor)
                    global_input_embeddings = torch.cat(
                        [global_input_embeddings, global_action_biases],
                        dim=-1)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]["global"] = (
                    global_input_embeddings,
                    global_output_embeddings,
                    list(global_action_ids),
                )

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(" -> ")[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(
                    entity_type_tensor)
                translated_valid_actions[key]["linked"] = (
                    entity_linking_scores,
                    entity_type_embeddings,
                    list(linked_action_ids),
                )
        return GrammarStatelet([START_SYMBOL], translated_valid_actions,
                               world.is_nonterminal)

    def _compute_validation_outputs(
        self,
        actions: List[List[ProductionRuleArray]],
        best_final_states: Mapping[int, Sequence[GrammarBasedState]],
        world: List[WikiTablesLanguage],
        target_list: List[List[str]],
        metadata: List[Dict[str, Any]],
        outputs: Dict[str, Any],
    ) -> None:
        """
        Does common things for validation time: computing logical form accuracy (which is expensive
        and unnecessary during training), adding visualization info to the output dictionary, etc.

        This doesn't return anything; instead it `modifies` the given ``outputs`` dictionary, and
        calls metrics on ``self``.
        """
        batch_size = len(actions)
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs["action_mapping"] = action_mapping
        outputs["best_action_sequence"] = []
        outputs["debug_info"] = []
        outputs["entities"] = []
        outputs["logical_form"] = []
        outputs["answer"] = []
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            outputs["logical_form"].append([])
            if i in best_final_states:
                all_action_indices = [
                    best_final_states[i][j].action_history[0]
                    for j in range(len(best_final_states[i]))
                ]
                found_denotation = False
                for action_indices in all_action_indices:
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in action_indices
                    ]
                    has_logical_form = False
                    try:
                        logical_form = world[
                            i].action_sequence_to_logical_form(action_strings)
                        has_logical_form = True
                    except ParsingError:
                        logical_form = "Error producing logical form"
                    if target_list is not None:
                        denotation_correct = world[i].evaluate_logical_form(
                            logical_form, target_list[i])
                    else:
                        denotation_correct = False
                    if not found_denotation:
                        try:
                            denotation = world[i].execute(logical_form)
                            if denotation:
                                outputs["answer"].append(denotation)
                                found_denotation = True
                        except ExecutionError:
                            pass
                        if found_denotation:
                            if has_logical_form:
                                self._has_logical_form(1.0)
                            else:
                                self._has_logical_form(0.0)
                            if target_list:
                                self._denotation_accuracy(
                                    1.0 if denotation_correct else 0.0)
                            outputs["best_action_sequence"].append(
                                action_strings)
                    outputs["logical_form"][-1].append(logical_form)
                if not found_denotation:
                    outputs["answer"].append(None)
                    self._denotation_accuracy(0.0)
                outputs["debug_info"].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
                outputs["entities"].append(world[i].table_graph.entities)
            else:
                self._has_logical_form(0.0)
                self._denotation_accuracy(0.0)

        if metadata is not None:
            outputs["question_tokens"] = [
                x["question_tokens"] for x in metadata
            ]

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.make_output_human_readable``, which gets called after
        ``Model.forward``, at test time, to finalize predictions.  This is (confusingly) a
        separate notion from the "decoder" in "encoder/decoder", where that decoder logic
        lives in the ``TransitionFunction``.

        This method collects, for each predicted action, the actions that were considered at
        that step along with their probabilities and question attention, and adds a field
        called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict["action_mapping"]
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict["debug_info"]
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(
                    predicted_actions, debug_info):
                action_info = {}
                action_info["predicted_action"] = predicted_action
                considered_actions = action_debug_info["considered_actions"]
                probabilities = action_debug_info["probabilities"]
                actions = []
                for action, probability in zip(considered_actions,
                                               probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)],
                                        probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info["considered_actions"] = considered_actions
                action_info["action_probabilities"] = probabilities
                action_info["question_attention"] = action_debug_info.get(
                    "question_attention", [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
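A standalone sketch (hypothetical numbers, not the parser's API) of the per-type linking softmax implemented in ``_get_linking_probabilities`` above: a null-entity column with score 0 is prepended so that a question word can put probability mass on linking to nothing, the softmax is taken within one entity type, and the null column is dropped afterwards.

import torch

# Scores for three entities of one type against four question tokens.
scores = torch.tensor([[ 2.0, -1.0,  0.5],
                       [ 0.0,  3.0,  1.0],
                       [-2.0, -2.0, -2.0],
                       [ 1.0,  1.0,  1.0]])  # (num_question_tokens, num_entities_of_type)

null_column = scores.new_zeros(scores.size(0), 1)
probabilities = torch.softmax(torch.cat([null_column, scores], dim=1), dim=1)[:, 1:]

# Each row sums to less than 1; the missing mass is the probability of "no link".
print(probabilities.sum(dim=1))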
class NlvrCoverageSemanticParser(NlvrSemanticParser):
    """
    ``NlvrCoverageSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of lack
    of annotated logical forms by maximizing coverage of the output sequences over a prespecified
    agenda. In addition to the signal from coverage, we also compute the denotations given by the
    logical forms and define a hybrid cost based on coverage and denotation errors. The training
    process then minimizes the expected value of this cost over an approximate set of logical forms
    produced by the parser, obtained by performing beam search.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the DecoderStep.
    beam_size : ``int``
        Beam size for the beam search used during training.
    max_num_finished_states : ``int``, optional (default=None)
        Maximum number of finished states the trainer should compute costs for.
    normalize_beam_score_by_length : ``bool``, optional (default=False)
        Should the log probabilities be normalized by length before renormalizing them? Edunov et
        al. do this in their work, but we found that not doing it works better. It's possible they
        did this because their task is NMT, and longer decoded sequences are not necessarily worse,
        and shouldn't be penalized, while we will mostly want to penalize longer logical forms.
    max_decoding_steps : ``int``
        Maximum number of steps for the beam search during training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    checklist_cost_weight : ``float``, optional (default=0.6)
        Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases, we
        weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
        denotation accuracy.
    dynamic_cost_weight : ``Dict[str, Union[int, float]]``, optional (default=None)
        A dict containing keys ``wait_num_epochs`` and ``rate`` indicating the number of epochs
        after which we should start decreasing the weight on checklist cost in favor of denotation
        cost, and the rate at which we should do it. We will decrease the weight in the following
        way - ``checklist_cost_weight = checklist_cost_weight - rate * checklist_cost_weight``
        starting at the appropriate epoch.  The weight will remain constant if this is not provided.
    penalize_non_agenda_actions : ``bool``, optional (default=False)
        Should we penalize the model for producing terminal actions that are outside the agenda?
    initial_mml_model_file : ``str``, optional (default=None)
        If you want to initialize this model using weights from another model trained using MML,
        pass the path to the ``model.tar.gz`` file of that model here.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 beam_size: int,
                 max_decoding_steps: int,
                 max_num_finished_states: int = None,
                 dropout: float = 0.0,
                 normalize_beam_score_by_length: bool = False,
                 checklist_cost_weight: float = 0.6,
                 dynamic_cost_weight: Dict[str, Union[int, float]] = None,
                 penalize_non_agenda_actions: bool = False,
                 initial_mml_model_file: str = None) -> None:
        super(NlvrCoverageSemanticParser, self).__init__(vocab=vocab,
                                                         sentence_embedder=sentence_embedder,
                                                         action_embedding_dim=action_embedding_dim,
                                                         encoder=encoder,
                                                         dropout=dropout)
        self._agenda_coverage = Average()
        self._decoder_trainer: DecoderTrainer[Callable[[NlvrDecoderState], torch.Tensor]] = \
                ExpectedRiskMinimization(beam_size=beam_size,
                                         normalize_by_length=normalize_beam_score_by_length,
                                         max_decoding_steps=max_decoding_steps,
                                         max_num_finished_states=max_num_finished_states)

        # Instantiating an empty NlvrWorld just to get the set of terminal productions.
        self._terminal_productions = set(NlvrWorld([]).terminal_productions.values())
        self._decoder_step = NlvrDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                             action_embedding_dim=action_embedding_dim,
                                             input_attention=attention,
                                             dropout=dropout,
                                             use_coverage=True)
        self._checklist_cost_weight = checklist_cost_weight
        self._dynamic_cost_wait_epochs = None
        self._dynamic_cost_rate = None
        if dynamic_cost_weight:
            self._dynamic_cost_wait_epochs = dynamic_cost_weight["wait_num_epochs"]
            self._dynamic_cost_rate = dynamic_cost_weight["rate"]
        self._penalize_non_agenda_actions = penalize_non_agenda_actions
        self._last_epoch_in_forward: int = None
        # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
        # copied a trained ERM model from a different machine and the original MML model that was
        # used to initialize it does not exist on the current machine. This may not be the best
        # solution for the problem.
        if initial_mml_model_file is not None:
            if os.path.isfile(initial_mml_model_file):
                archive = load_archive(initial_mml_model_file)
                self._initialize_weights_from_archive(archive)
            else:
                # A model file is passed, but it does not exist. This is expected to happen when
                # you're using a trained ERM model to decode. But it may also happen if the path to
                # the file is really just incorrect. So throwing a warning.
                logger.warning("MML model file for initializing weights is passed, but does not exist."
                               " This is fine if you're just decoding.")

    def _initialize_weights_from_archive(self, archive: Archive) -> None:
        logger.info("Initializing weights from MML model.")
        model_parameters = dict(self.named_parameters())
        archived_parameters = dict(archive.model.named_parameters())
        sentence_embedder_weight = "_sentence_embedder.token_embedder_tokens.weight"
        if sentence_embedder_weight not in archived_parameters or \
           sentence_embedder_weight not in model_parameters:
            raise RuntimeError("When initializing model weights from an MML model, we need "
                               "the sentence embedder to be a TokenEmbedder using namespace called "
                               "tokens.")
        for name, weights in archived_parameters.items():
            if name in model_parameters:
                if name == "_sentence_embedder.token_embedder_tokens.weight":
                    # The shapes of embedding weights will most likely differ between the two models
                    # because the vocabularies will most likely be different. We will get a mapping
                    # of indices from this model's token indices to the archived model's and copy
                    # the tensor accordingly.
                    vocab_index_mapping = self._get_vocab_index_mapping(archive.model.vocab)
                    archived_embedding_weights = weights.data
                    new_weights = model_parameters[name].data.clone()
                    for index, archived_index in vocab_index_mapping:
                        new_weights[index] = archived_embedding_weights[archived_index]
                    logger.info("Copied embeddings of %d out of %d tokens",
                                len(vocab_index_mapping), new_weights.size()[0])
                else:
                    new_weights = weights.data
                logger.info("Copying parameter %s", name)
                model_parameters[name].data.copy_(new_weights)

    def _get_vocab_index_mapping(self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]:
        vocab_index_mapping: List[Tuple[int, int]] = []
        for index in range(self.vocab.get_vocab_size(namespace='tokens')):
            token = self.vocab.get_token_from_index(index=index, namespace='tokens')
            archived_token_index = archived_vocab.get_token_index(token, namespace='tokens')
            # Checking if we got the UNK token index, because we don't want all new token
            # representations initialized to UNK token's representation. We do that by checking if
            # the two tokens are the same. They will not be if the token at the archived index is
            # UNK.
            if archived_vocab.get_token_from_index(archived_token_index, namespace="tokens") == token:
                vocab_index_mapping.append((index, archived_token_index))
        return vocab_index_mapping

    @overrides
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRuleArray]],
                agenda: torch.LongTensor,
                identifier: List[str] = None,
                labels: torch.LongTensor = None,
                epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences that maximize coverage of
        their respective agendas, and minimize a denotation based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError("If you want a dynamic cost weight, use the "
                                   "EpochTrackingBucketIterator!")
            if instance_epoch_num != self._last_epoch_in_forward:
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
                                 range(batch_size)]

        label_strings = self._get_label_strings(labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            initial_checklist_states.append(ChecklistState(terminal_actions=terminal_actions,
                                                           checklist_target=checklist_target,
                                                           checklist_mask=checklist_mask,
                                                           checklist=initial_checklist))
        initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                         action_history=[[] for _ in range(batch_size)],
                                         score=initial_score_list,
                                         rnn_state=initial_rnn_state,
                                         grammar_state=initial_grammar_state,
                                         action_embeddings=action_embeddings,
                                         action_indices=action_indices,
                                         possible_actions=actions,
                                         worlds=worlds,
                                         label_strings=label_strings,
                                         checklist_state=initial_checklist_states)

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(initial_state,
                                               self._decoder_step,
                                               self._get_state_cost)
        if identifier is not None:
            outputs['identifier'] = identifier
        best_action_sequences = outputs['best_action_sequences']
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings,
                                 possible_actions=actions,
                                 agenda_data=agenda_data)
        else:
            # We're testing.
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs

    def _get_checklist_info(self,
                            agenda: torch.LongTensor,
                            all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                             torch.Tensor,
                                                                             torch.Tensor]:
        """
        Takes an agenda and a list of all actions and returns a target checklist against which the
        checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
        and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
        checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``,
        ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
        ``False``, indices of all terminals that are not in the agenda will be masked.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRuleArray]``
            All actions for one instance.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRuleArray, a tuple where the first item is the production
            # rule string.
            if action[0] in self._terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        # (num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        if self._penalize_non_agenda_actions:
            # All terminal actions are relevant
            checklist_mask = torch.ones_like(target_checklist)
        else:
            checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
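    # A toy illustration of the method above, with hypothetical indices: if
    # actions 3, 5 and 8 are the terminal productions and the agenda is [5, 8],
    # then terminal_actions = [[3], [5], [8]], target_checklist =
    # [[0.], [1.], [1.]], and with _penalize_non_agenda_actions=False the mask
    # is also [[0.], [1.], [1.]], so action 3 never contributes to the loss.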

    def _update_metrics(self,
                        action_strings: List[List[List[str]]],
                        worlds: List[List[NlvrWorld]],
                        label_strings: List[List[str]],
                        possible_actions: List[List[ProductionRuleArray]],
                        agenda_data: List[List[int]]) -> None:
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): action_strings contains k-best lists. This method only uses the top decoded
        # sequence currently. Maybe define top-k metrics?
        batch_size = len(worlds)
        for i in range(batch_size):
            # Using only the top decoded sequence per instance.
            instance_action_strings = action_strings[i][0] if action_strings[i] else []
            sequence_is_correct = [False]
            in_agenda_ratio = 0.0
            instance_possible_actions = possible_actions[i]
            if instance_action_strings:
                terminal_agenda_actions = []
                for rule_id in agenda_data[i]:
                    if rule_id == -1:
                        continue
                    action_string = instance_possible_actions[rule_id][0]
                    right_side = action_string.split(" -> ")[1]
                    if right_side.isdigit() or ('[' not in right_side and len(right_side) > 1):
                        terminal_agenda_actions.append(action_string)
                actions_in_agenda = [action in instance_action_strings for action in
                                     terminal_agenda_actions]
                in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda)
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                sequence_is_correct = self._check_denotation(instance_action_strings,
                                                             instance_label_strings,
                                                             instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)
            self._agenda_coverage(in_agenda_ratio)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
                'denotation_accuracy': self._denotation_accuracy.get_metric(reset),
                'consistency': self._consistency.get_metric(reset),
                'agenda_coverage': self._agenda_coverage.get_metric(reset)
        }

    def _get_state_cost(self, state: NlvrDecoderState) -> torch.Tensor:
        """
        Return the cost of a finished state. Since it is a finished state, the group size will be 1,
        and hence we'll return just one cost.
        """
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask.
        checklist_balance = state.checklist_state[0].get_balance()
        checklist_cost = torch.sum(checklist_balance ** 2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        # Note: If we are penalizing the model for producing non-agenda actions, this is not the
        # upper limit on the checklist cost. That would be the number of terminal actions.
        denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        # TODO (pradeep): The denotation-based cost below is strict. Maybe define a cost based on
        # how many worlds the logical form is correct in?
        # label_strings being None happens when we are testing. We do not care about the cost then.
        # TODO (pradeep): Make this cleaner.
        if state.label_strings is None or all(self._check_state_denotations(state)):
            cost = checklist_cost
        else:
            cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
        return cost
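    # A worked example with hypothetical values: if _checklist_cost_weight = 0.6,
    # the masked squared checklist error is 2.0, and the agenda wanted 4 terminal
    # actions, then a denotationally correct parse costs 0.6 * 2.0 = 1.2, while
    # an incorrect one costs 0.6 * 2.0 + (1 - 0.6) * 4 = 2.8.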

    def _get_state_info(self, state) -> Dict[str, List]:
        """
        This method is here for debugging purposes, in case you want to look at what the model
        is learning. It may be inefficient to call it while training the model on real data.
        """
        if len(state.batch_indices) == 1 and state.is_finished():
            costs = [float(self._get_state_cost(state).detach().cpu().numpy())]
        else:
            costs = []
        model_scores = [float(score.detach().cpu().numpy()) for score in state.score]
        all_actions = state.possible_actions[0]
        action_sequences = [[self._get_action_string(all_actions[action]) for action in history]
                            for history in state.action_history]
        agenda_sequences = []
        all_agenda_indices = []
        # ``terminal_actions`` and ``checklist_target`` live on each instance's
        # ``ChecklistState``, so we read them from there.
        for checklist_state in state.checklist_state:
            agenda = checklist_state.terminal_actions
            checklist_target = checklist_state.checklist_target
            agenda_indices = []
            for action, is_wanted in zip(agenda, checklist_target):
                action_int = int(action.detach().cpu().numpy())
                is_wanted_int = int(is_wanted.detach().cpu().numpy())
                if is_wanted_int != 0:
                    agenda_indices.append(action_int)
            agenda_sequences.append([self._get_action_string(all_actions[action])
                                     for action in agenda_indices])
            all_agenda_indices.append(agenda_indices)
        return {"agenda": agenda_sequences,
                "agenda_indices": all_agenda_indices,
                "history": action_sequences,
                "history_indices": state.action_history,
                "costs": costs,
                "scores": model_scores}
Example No. 20
class Text2SqlParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    input_attention: ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    """
    def __init__(self,
                 vocab: Vocabulary,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)

        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        # the padding value used by IndexField
        self._action_padding_index = -1
        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                            action_embedding_dim=action_embedding_dim,
                                                            input_attention=input_attention,
                                                            predict_start_type_separately=False,
                                                            add_action_bias=self._add_action_bias,
                                                            dropout=dropout)
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                valid_actions: List[List[ProductionRule]],
                action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : ``torch.LongTensor``, optional (default=None)
            The gold action sequence, where each action is an index into the list of possible
            actions.  This tensor has shape ``(batch_size, sequence_length, 1)``; we remove the
            trailing dimension.
        """
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        batch_size = embedded_utterance.size(0)

        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
        initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # action_sequence is of shape (batch_size, 1, target_sequence_length)
            # here after we unsqueeze it for the MML trainer.
            loss_output = self._decoder_trainer.decode(initial_state,
                                                       self._transition_function,
                                                       (action_sequence.unsqueeze(1),
                                                        target_mask.unsqueeze(1)))
            outputs.update(loss_output)

        if not self.training:
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=True)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[i][action_index]
                                  for action_index in best_action_indices]

                predicted_sql_query = action_sequence_to_sql(action_strings)
                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                outputs['best_action_sequence'].append(action_strings)
                outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs
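    # Note: ``sqlparse.format(query, reindent=True)`` above is just
    # pretty-printing; e.g. "SELECT * FROM city WHERE city_name = 'boston'"
    # comes back as a multi-line, indented SELECT statement, which makes the
    # predicted queries easier to eyeball in the output dictionary.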


    def _get_initial_state(self,
                           encoder_outputs: torch.Tensor,
                           mask: torch.Tensor,
                           actions: List[List[ProductionRule]]) -> GrammarBasedState:

        batch_size = encoder_outputs.size(0)
        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        initial_score = encoder_outputs.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_utterance,
                                                 encoder_output_list,
                                                 utterance_mask_list))

        initial_grammar_state = [self._create_grammar_state(actions[i]) for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          debug_info=None)
        return initial_state

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence exactly matches the beginning of the target sequence.
        return predicted_tensor.equal(targets_trimmed)
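    # Example: predicted [2, 5, 7] against targets [2, 5, 7, 0, 0] returns 1,
    # because the prediction matches the target prefix exactly; predicted
    # [2, 5, 9] would return 0.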

    @staticmethod
    def is_nonterminal(token: str) -> bool:
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track four metrics here:

            1. exact_match, which is the percentage of the time that our best output action sequence
            matches the SQL query exactly.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that can be parsed. (make sure
            you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. valid_sql_query, which is the percentage of time that decoding actually produces a
            valid SQL query.  We might not produce a valid SQL query if the decoder gets
            into a repetitive loop, or we're trying to produce a super long SQL query and run
            out of time steps, or something.

            4. action_similarity, which is how similar the action sequence predicted is to the actual
            action sequence. This is basically a soft measure of exact_match.
        """

        validation_correct = self._exact_match._total_value # pylint: disable=protected-access
        validation_total = self._exact_match._count # pylint: disable=protected-access
        return {
                '_exact_match_count': validation_correct,
                '_example_count': validation_total,
                'exact_match': self._exact_match.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'valid_sql_query': self._valid_sql_query.get_metric(reset),
                'action_similarity': self._action_similarity.get_metric(reset)
                }

    def _create_grammar_state(self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  This parser only supports global actions, so there are no
        linked actions to handle here.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        """
        device = util.get_device_of(self._action_embedder.weight)
        # TODO(Mark): This type is pure \(- . ^)/
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}

        actions_grouped_by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list)
        for i, action in enumerate(possible_actions):
            if action.rule == "":
                continue
            if action.is_global_rule:
                actions_grouped_by_nonterminal[action.nonterminal].append((action, i))
            else:
                raise ValueError("The sql parser doesn't support non-global actions yet.")

        for key, production_rule_arrays in actions_grouped_by_nonterminal.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `production_rule_arrays` contains
            # all the valid productions of that non-terminal.  This parser only produces
            # global actions, so there is no global vs. linked split to make.
            global_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                global_actions.append((production_rule_array.rule_id, action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0).long()
                if device >= 0:
                    global_action_tensor = global_action_tensor.to(device)

                global_input_embeddings = self._action_embedder(global_action_tensor)
                global_output_embeddings = self._output_action_embedder(global_action_tensor)

                translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                           global_output_embeddings,
                                                           list(global_action_ids))
        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal,
                               reverse_productions=True)
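    # The resulting structure, with hypothetical shapes, looks like:
    # translated_valid_actions['query']['global'] ==
    #     (input embeddings of shape (num_rules, action_embedding_dim + 1),
    #      output embeddings of shape (num_rules, action_embedding_dim),
    #      [indices of those rules in possible_actions])
    # where the + 1 on the input side comes from the learned action bias.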

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[batch_index][action], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['utterance_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
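The action_similarity metric used by the SQL parsers in this document is plain standard-library difflib, so its behavior is easy to check in isolation; the action indices below are made up:

import difflib

predicted = [4, 17, 9, 23]
target = [4, 17, 9, 31]
# ratio() is 2 * M / T, where M is the number of matching elements and T the
# total length of both sequences: 2 * 3 / (4 + 4) = 0.75 here.
print(difflib.SequenceMatcher(None, predicted, target).ratio())  # 0.75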
Example No. 21
class STS14Task(Task):
    '''
    Task class for Semantic Textual Similarity 2014 (STS14).
    Training data is STS12 and STS13 data, as provided in the dataset.
    '''
    def __init__(self, path, max_seq_len, name="sts14"):
        ''' '''
        super(STS14Task, self).__init__(name, 1)
        self.name = name
        self.pair_input = 1
        self.categorical = 0
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer = Average()
        self.load_data(path, max_seq_len)

    def load_data(self, path, max_seq_len):
        '''
        Process the dataset located at path.

        TODO: preprocess and store data so we don't have to wait?

        Args:
            - path (str): path to data
            - max_seq_len (int): maximum sequence length used when processing sentences
        '''

        def load_year_split(path):
            sents1, sents2, targs = [], [], []
            input_files = glob.glob('%s/STS.input.*.txt' % path)
            targ_files = glob.glob('%s/STS.gs.*.txt' % path)
            input_files.sort()
            targ_files.sort()
            for inp, targ in zip(input_files, targ_files):
                topic_sents1, topic_sents2, topic_targs = \
                        load_file(path, inp, targ)
                sents1 += topic_sents1
                sents2 += topic_sents2
                targs += topic_targs
            assert len(sents1) == len(sents2) == len(targs)
            return sents1, sents2, targs

        def load_file(path, inp, targ):
            sents1, sents2, targs = [], [], []
            with open(inp) as fh, open(targ) as gh:
                for raw_sents, raw_targ in zip(fh, gh):
                    raw_sents = raw_sents.split('\t')
                    sent1 = process_sentence(raw_sents[0], max_seq_len)
                    sent2 = process_sentence(raw_sents[1], max_seq_len)
                    if not sent1 or not sent2:
                        continue
                    sents1.append(sent1)
                    sents2.append(sent2)
                    targs.append(float(raw_targ) / 5) # rescale for cosine
            return sents1, sents2, targs

        sort_data = lambda s1, s2, t: \
            sorted(zip(s1, s2, t), key=lambda x: (len(x[0]), len(x[1])))
        unpack = lambda x: [l for l in map(list, zip(*x))]

        sts2topics = {
            12: ['MSRpar', 'MSRvid', 'SMTeuroparl', 'surprise.OnWN', \
                    'surprise.SMTnews'],
            13: ['FNWN', 'headlines', 'OnWN'],
            14: ['deft-forum', 'deft-news', 'headlines', 'images', \
                    'OnWN', 'tweet-news']
            }

        sents1, sents2, targs = [], [], []
        train_dirs = ['STS2012-train', 'STS2012-test', 'STS2013-test']
        for train_dir in train_dirs:
            res = load_year_split(path + train_dir + '/')
            sents1 += res[0]
            sents2 += res[1]
            targs += res[2]
        data = [(s1, s2, t) for s1, s2, t in zip(sents1, sents2, targs)]
        random.shuffle(data)
        sents1, sents2, targs = unpack(data)
        split_pt = int(.8 * len(sents1))
        tr_data = sort_data(sents1[:split_pt], sents2[:split_pt],
                targs[:split_pt])
        val_data = sort_data(sents1[split_pt:], sents2[split_pt:],
                targs[split_pt:])
        te_data = sort_data(*load_year_split(path))

        self.train_data_text = unpack(tr_data)
        self.val_data_text = unpack(val_data)
        self.test_data_text = unpack(te_data)
        log.info("\tFinished loading STS14 data.")

    def get_metrics(self, reset=False):
        return {'accuracy': self.scorer.get_metric(reset)}
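The sort_data and unpack lambdas above implement a common bucketing idiom: sort paired sentences by length so batches hold similarly sized inputs, then transpose the list of triples back into parallel lists. A self-contained sketch with toy data:

sort_data = lambda s1, s2, t: sorted(zip(s1, s2, t), key=lambda x: (len(x[0]), len(x[1])))
unpack = lambda x: [l for l in map(list, zip(*x))]

sents1 = [['a', 'b', 'c'], ['a']]
sents2 = [['x'], ['x', 'y']]
targs = [0.4, 0.9]
sents1, sents2, targs = unpack(sort_data(sents1, sents2, targs))
print(sents1)  # [['a'], ['a', 'b', 'c']] -- shortest first sentence comes first
print(targs)   # [0.9, 0.4] -- targets stay aligned with their sentence pairs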
Example No. 22
class AtisSemanticParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    input_attention: ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    database_file: ``str``, optional (default=/atis/atis.db)
        The path of the SQLite database when evaluating SQL queries. SQLite is disk based, so we need
        the file location to connect to it.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels',
                 database_file: str = '/atis/atis.db') -> None:
        # Atis semantic parser init
        super().__init__(vocab)
        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        self._executor = SqlExecutor(database_file)
        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)


        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._num_entity_types = 2  # TODO(kevin): get this in a more principled way somehow?
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)
        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._transition_function = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                              action_embedding_dim=action_embedding_dim,
                                                              input_attention=input_attention,
                                                              predict_start_type_separately=False,
                                                              add_action_bias=self._add_action_bias,
                                                              dropout=dropout,
                                                              num_layers=self._decoder_num_layers)

    @overrides
    def forward(self,  # type: ignore
                utterance: Dict[str, torch.LongTensor],
                world: List[AtisWorld],
                actions: List[List[ProductionRule]],
                linking_scores: torch.Tensor,
                target_action_sequence: torch.LongTensor = None,
                sql_queries: List[List[str]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        utterance : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        world : ``List[AtisWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        linking_scores: ``torch.Tensor``
            A matrix linking the utterance tokens and the entities. This is a binary matrix that
            is deterministically generated where each entry indicates whether a token generated an entity.
            This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
        target_action_sequence : torch.Tensor, optional (default=None)
            The gold action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.
        sql_queries : List[List[str]], optional (default=None)
            A list of the SQL queries that are given during training or validation.
        """
        initial_state = self._get_initial_state(utterance, world, actions, linking_scores)
        batch_size = linking_scores.shape[0]
        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we unsqueeze it for
            # the MML trainer.
            return self._decoder_trainer.decode(initial_state,
                                                self._transition_function,
                                                (target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
        else:
            # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            outputs['linking_scores'] = linking_scores
            if target_action_sequence is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._transition_function,
                                                               (target_action_sequence.unsqueeze(1),
                                                                target_mask.unsqueeze(1)))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            outputs['utterance'] = []
            outputs['tokenized_utterance'] = []

            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[(i, action_index)]
                                  for action_index in best_action_indices]
                predicted_sql_query = action_sequence_to_sql(action_strings)

                if target_action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = target_action_sequence[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                if sql_queries and sql_queries[i]:
                    denotation_correct = self._executor.evaluate_sql_query(predicted_sql_query, sql_queries[i])
                    self._denotation_accuracy(denotation_correct)
                    outputs['sql_queries'].append(sql_queries[i])

                outputs['utterance'].append(world[i].utterances[-1])
                outputs['tokenized_utterance'].append([token.text
                                                       for token in world[i].tokenized_utterances[-1]])
                outputs['entities'].append(world[i].entities)
                outputs['best_action_sequence'].append(action_strings)
                outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
            return outputs

    def _get_initial_state(self,
                           utterance: Dict[str, torch.LongTensor],
                           worlds: List[AtisWorld],
                           actions: List[List[ProductionRule]],
                           linking_scores: torch.Tensor) -> GrammarBasedState:
        embedded_utterance = self._utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size = embedded_utterance.size(0)
        num_entities = max([len(world.entities) for world in worlds])

        # entity_types: tensor with shape (batch_size, num_entities)
        entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance)

        # (batch_size, num_utterance_tokens, embedding_dim)
        encoder_input = embedded_utterance

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             utterance_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                initial_rnn_state.append(RnnStatelet(final_encoder_output[i].repeat(self._decoder_num_layers, 1),
                                                     memory_cell[i].repeat(self._decoder_num_layers, 1),
                                                     self._first_action_embedding,
                                                     self._first_attended_utterance,
                                                     encoder_output_list,
                                                     utterance_mask_list))
            else:
                initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                     memory_cell[i],
                                                     self._first_action_embedding,
                                                     self._first_attended_utterance,
                                                     encoder_output_list,
                                                     utterance_mask_list))


        initial_grammar_state = [self._create_grammar_state(worlds[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            entity_types[i])
                                 for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          debug_info=None)
        return initial_state

    @staticmethod
    def _get_type_vector(worlds: List[AtisWorld],
                         num_entities: int,
                         tensor: torch.Tensor = None) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the encoding for each entity's type. In addition, a map from a flattened entity
        index to type is returned to combine entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[AtisWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []

        for batch_index, world in enumerate(worlds):
            types = []
            entities = [('number', entity)
                        if any([entity.startswith(numeric_nonterminal)
                                for numeric_nonterminal in NUMERIC_NONTERMINALS])
                        else ('string', entity)
                        for entity in world.entities]

            for entity_index, entity in enumerate(entities):
                # We need numbers to be first, then strings, since our entities are going to be
                # sorted. We do a split by type and then a merge later, and it relies on this sorting.
                if entity[0] == 'number':
                    entity_type = 1
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)

        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types
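    # Toy illustration: for a single world with entities [e0, e1], where only
    # e0 starts with one of NUMERIC_NONTERMINALS, and num_entities = 3 across
    # the batch, types = [1, 0] is padded to [1, 0, 0] and the flattened map
    # is {0: 1, 1: 0}.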

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence exactly matches the beginning of the target sequence.
        return predicted_tensor.equal(targets_trimmed)

    @staticmethod
    def is_nonterminal(token: str) -> bool:
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track four metrics here:

            1. exact_match, which is the percentage of the time that our best output action sequence
            matches the SQL query exactly.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that can be parsed. (make sure
            you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. valid_sql_query, which is the percentage of time that decoding actually produces a
            valid SQL query.  We might not produce a valid SQL query if the decoder gets
            into a repetitive loop, or we're trying to produce a super long SQL query and run
            out of time steps, or something.

            4. action_similarity, which is how similar the action sequence predicted is to the actual
            action sequence. This is basically a soft measure of exact_match.
        """
        return {
                'exact_match': self._exact_match.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'valid_sql_query': self._valid_sql_query.get_metric(reset),
                'action_similarity': self._action_similarity.get_metric(reset)
                }

    def _create_grammar_state(self,
                              world: AtisWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``AtisWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_utterance_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = world.valid_actions
        entity_map = {}
        entities = world.entities

        for entity_index, entity in enumerate(entities):
            entity_map[entity] = entity_index

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [action_map[action_string] for action_string in action_strings]
            production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append((production_rule_array[2], action_index))
                else:
                    linked_actions.append((production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0).to(entity_types.device).long()
                global_input_embeddings = self._action_embedder(global_action_tensor)
                global_output_embeddings = self._output_action_embedder(global_action_tensor)
                translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                           global_output_embeddings,
                                                           list(global_action_ids))
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = linked_rules
                entity_ids = [entity_map[entity] for entity in entities]
                entity_linking_scores = linking_scores[entity_ids]
                entity_type_tensor = entity_types[entity_ids]
                entity_type_embeddings = (self._entity_type_decoder_embedding(entity_type_tensor)
                                          .to(entity_types.device)
                                          .float())
                translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                           entity_type_embeddings,
                                                           list(linked_action_ids))

        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal)
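
    # Illustrative (hypothetical) shape of the translated_valid_actions dictionary built
    # above, for a single non-terminal:
    #
    #     {'statement': {
    #         'global': (input_embeddings,   # (num_global, action_embedding_dim)
    #                    output_embeddings,  # (num_global, action_embedding_dim)
    #                    [3, 17, 42]),       # indices into possible_actions
    #         'linked': (linking_scores,     # (num_linked, num_utterance_tokens)
    #                    type_embeddings,    # (num_linked, action_embedding_dim)
    #                    [5, 9])}}           # indices of the linked actions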

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['utterance_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
class WikiTablesErmSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesErmSemanticParser`` is a :class:`WikiTablesSemanticParser` that learns to search
    for logical forms that yield the correct denotations.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder to used for averaging the words of an entity. Passed to super class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    decoder_beam_size : ``int``
        Beam size to be used by the ExpectedRiskMinimization algorithm.
    decoder_num_finished_states : ``int``
        Number of finished states for which costs will be computed by the ExpectedRiskMinimization
        algorithm.
    max_decoding_steps : ``int``
        Maximum number of steps the decoder should take before giving up. Used both during training
        and evaluation. Passed to super class.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.  Passed to super class.
    normalize_beam_score_by_length : ``bool``, optional (default=False)
        Should we normalize the log-probabilities by length before renormalizing the beam? This was
        shown to work better for NMT by Edunov et al., but that may not be the case for semantic
        parsing.
    checklist_cost_weight : ``float``, optional (default=0.6)
        Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases, we
        weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
        denotation accuracy.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to super
        class.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables. Passed to super class.
    mml_model_file : ``str``, optional (default=None)
        If you want to initialize this model using weights from another model trained using MML,
        pass the path to the ``model.tar.gz`` file of that model here.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 attention: Attention,
                 decoder_beam_size: int,
                 decoder_num_finished_states: int,
                 max_decoding_steps: int,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 normalize_beam_score_by_length: bool = False,
                 checklist_cost_weight: float = 0.6,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/',
                 mml_model_file: str = None) -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(vocab=vocab,
                         question_embedder=question_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         entity_encoder=entity_encoder,
                         max_decoding_steps=max_decoding_steps,
                         add_action_bias=add_action_bias,
                         use_neighbor_similarity_for_linking=use_similarity,
                         dropout=dropout,
                         num_linking_features=num_linking_features,
                         rule_namespace=rule_namespace,
                         tables_directory=tables_directory)
        # Not sure why mypy needs a type annotation for this!
        self._decoder_trainer: ExpectedRiskMinimization = \
                ExpectedRiskMinimization(beam_size=decoder_beam_size,
                                         normalize_by_length=normalize_beam_score_by_length,
                                         max_decoding_steps=self._max_decoding_steps,
                                         max_num_finished_states=decoder_num_finished_states)
        unlinked_terminals_global_indices = []
        global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace)
        for production, index in global_vocab.items():
            right_side = production.split(" -> ")[1]
            if right_side in types.COMMON_NAME_MAPPING:
                # This is a terminal production.
                unlinked_terminals_global_indices.append(index)
        self._num_unlinked_terminals = len(unlinked_terminals_global_indices)
        self._decoder_step = LinkingCoverageTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                               action_embedding_dim=action_embedding_dim,
                                                               input_attention=attention,
                                                               num_start_types=self._num_start_types,
                                                               predict_start_type_separately=True,
                                                               add_action_bias=self._add_action_bias,
                                                               mixture_feedforward=mixture_feedforward,
                                                               dropout=dropout)
        self._checklist_cost_weight = checklist_cost_weight
        self._agenda_coverage = Average()
        # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
        # copied a trained ERM model from a different machine and the original MML model that was
        # used to initialize it does not exist on the current machine. This may not be the best
        # solution for the problem.
        if mml_model_file is not None:
            if os.path.isfile(mml_model_file):
                archive = load_archive(mml_model_file)
                self._initialize_weights_from_archive(archive)
            else:
                # A model file is passed, but it does not exist. This is expected to happen when
                # you're using a trained ERM model to decode. But it may also happen if the path to
                # the file is simply incorrect, so we log a warning instead of raising an error.
                logger.warning("MML model file for initializing weights is passed, but does not exist."
                               " This is fine if you're just decoding.")

    def _initialize_weights_from_archive(self, archive: Archive) -> None:
        logger.info("Initializing weights from MML model.")
        model_parameters = dict(self.named_parameters())
        archived_parameters = dict(archive.model.named_parameters())
        question_embedder_weight = "_question_embedder.token_embedder_tokens.weight"
        if question_embedder_weight not in archived_parameters or \
           question_embedder_weight not in model_parameters:
            raise RuntimeError("When initializing model weights from an MML model, we need "
                               "the question embedder to be a TokenEmbedder using namespace called "
                               "tokens.")
        for name, weights in archived_parameters.items():
            if name in model_parameters:
                if name == question_embedder_weight:
                    # The shapes of embedding weights will most likely differ between the two models
                    # because the vocabularies will most likely be different. We will get a mapping
                    # of indices from this model's token indices to the archived model's and copy
                    # the tensor accordingly.
                    vocab_index_mapping = self._get_vocab_index_mapping(archive.model.vocab)
                    archived_embedding_weights = weights.data
                    new_weights = model_parameters[name].data.clone()
                    for index, archived_index in vocab_index_mapping:
                        new_weights[index] = archived_embedding_weights[archived_index]
                    logger.info("Copied embeddings of %d out of %d tokens",
                                len(vocab_index_mapping), new_weights.size()[0])
                else:
                    new_weights = weights.data
                logger.info("Copying parameter %s", name)
                model_parameters[name].data.copy_(new_weights)

    def _get_vocab_index_mapping(self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]:
        vocab_index_mapping: List[Tuple[int, int]] = []
        for index in range(self.vocab.get_vocab_size(namespace='tokens')):
            token = self.vocab.get_token_from_index(index=index, namespace='tokens')
            archived_token_index = archived_vocab.get_token_index(token, namespace='tokens')
            # Checking if we got the UNK token index, because we don't want all new token
            # representations initialized to UNK token's representation. We do that by checking if
            # the two tokens are the same. They will not be if the token at the archived index is
            # UNK.
            if archived_vocab.get_token_from_index(archived_token_index, namespace="tokens") == token:
                vocab_index_mapping.append((index, archived_token_index))
        return vocab_index_mapping
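
    # Worked example (hypothetical vocabularies, not from the original listing): if this
    # model's tokens are ['the', 'flight', 'cost'] at indices [0, 1, 2], and the archived
    # vocabulary maps 'the' to 4, 'flight' to the UNK index, and 'cost' to 7, the mapping
    # returned is [(0, 4), (2, 7)]; 'flight' is skipped so it keeps its fresh
    # initialization instead of copying the UNK row.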

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRule]],
                agenda: torch.LongTensor,
                example_lisp_string: List[str],
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' key.
        """
        batch_size = list(question.values())[0].size(0)
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        checklist_states = []
        all_terminal_productions = [set(instance_world.terminal_productions.values())
                                    for instance_world in world]
        max_num_terminals = max([len(terminals) for terminals in all_terminal_productions])
        for instance_actions, instance_agenda, terminal_productions in zip(actions,
                                                                           agenda_list,
                                                                           all_terminal_productions):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions,
                                                      terminal_productions,
                                                      max_num_terminals)
            checklist_target, terminal_actions, checklist_mask = checklist_info
            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            checklist_states.append(ChecklistStatelet(terminal_actions=terminal_actions,
                                                      checklist_target=checklist_target,
                                                      checklist_mask=checklist_mask,
                                                      checklist=initial_checklist))
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                           table,
                                                                           world,
                                                                           actions,
                                                                           outputs)

        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = CoverageState(batch_indices=list(range(batch_size)),  # type: ignore
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=rnn_state,
                                      grammar_state=grammar_state,
                                      checklist_state=checklist_states,
                                      possible_actions=actions,
                                      extras=example_lisp_string,
                                      debug_info=None)

        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

        outputs = self._decoder_trainer.decode(initial_state,  # type: ignore
                                               self._decoder_step,
                                               partial(self._get_state_cost, world))
        best_final_states = outputs['best_final_states']

        if not self.training:
            batch_size = len(actions)
            agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda]
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            for i in range(batch_size):
                in_agenda_ratio = 0.0
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    action_sequence = best_final_states[i][0].action_history[0]
                    action_strings = [action_mapping[(i, action_index)] for action_index in action_sequence]
                    instance_possible_actions = actions[i]
                    agenda_actions = []
                    for rule_id in agenda_indices[i]:
                        rule_id = int(rule_id)
                        if rule_id == -1:
                            continue
                        action_string = instance_possible_actions[rule_id][0]
                        agenda_actions.append(action_string)
                    actions_in_agenda = [action in action_strings for action in agenda_actions]
                    if actions_in_agenda:
                        # Note: This means that when there are no actions on agenda, agenda coverage
                        # will be 0, not 1.
                        in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda)
                self._agenda_coverage(in_agenda_ratio)

            self._compute_validation_outputs(actions,
                                             best_final_states,
                                             world,
                                             example_lisp_string,
                                             metadata,
                                             outputs)
        return outputs
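
    # Worked example of in_agenda_ratio (hypothetical numbers, not from the original
    # listing): if an instance's agenda holds 4 actions and 3 of them occur in the decoded
    # action sequence, agenda coverage for that instance is 3 / 4 = 0.75. An empty agenda
    # yields 0.0 (not 1.0), as noted in the loop above.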

    @staticmethod
    def _get_checklist_info(agenda: torch.LongTensor,
                            all_actions: List[ProductionRule],
                            terminal_productions: Set[str],
                            max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
        world, and a length to pad the checklist vectors to, and returns a target checklist against
        which the checklist at each state will be compared to compute a loss, indices of
        ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions
        are relevant for checklist loss computation.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRule]``
            All actions for one instance.
        ``terminal_productions`` : ``Set[str]``
            String representations of terminal productions in the corresponding world.
        ``max_num_terminals`` : ``int``
            Length to which the checklist vectors will be padded. This is the maximum number of
            terminal productions across all the worlds in the batch.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRule, a tuple where the first item is the production
            # rule string.
            if action[0] in terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        while len(target_checklist_list) < max_num_terminals:
            target_checklist_list.append([0])
            terminal_indices.append([-1])
        # (max_num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (max_num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
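
    # Worked example (hypothetical indices, not from the original listing): with terminal
    # actions at indices [2, 5, 7], an agenda of {5}, and max_num_terminals = 4, we get
    #     terminal_actions = [[2], [5], [7], [-1]]   (padded with -1)
    #     target_checklist = [[0], [1], [0], [0]]
    #     checklist_mask   = [[0], [1], [0], [0]]
    # so only the agenda action at index 5 contributes to the checklist loss.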

    def _get_state_cost(self, worlds: List[WikiTablesWorld], state: CoverageState) -> torch.Tensor:
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        world = worlds[state.batch_indices[0]]

        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
        # penalizing agenda actions produced multiple times.
        checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0)
        checklist_cost = torch.sum(checklist_balance ** 2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        action_history = state.action_history[0]
        batch_index = state.batch_indices[0]
        action_strings = [state.possible_actions[batch_index][i][0] for i in action_history]
        logical_form = world.get_logical_form(action_strings)
        lisp_string = state.extras[batch_index]
        if self._executor.evaluate_logical_form(logical_form, lisp_string):
            cost = checklist_cost
        else:
            cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
        return cost
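
    # Worked example of the cost mixture (hypothetical numbers, not from the original
    # listing): with checklist_cost_weight = 0.6, a raw checklist cost of 2.0, and a
    # denotation cost of 3.0, a correct denotation costs 0.6 * 2.0 = 1.2, while an
    # incorrect one costs 0.6 * 2.0 + (1 - 0.6) * 3.0 = 2.4.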

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        The base class returns a dict with dpd accuracy, denotation accuracy, and logical form
        percentage metrics. We add the agenda coverage metric here.
        """
        metrics = super().get_metrics(reset)
        metrics["agenda_coverage"] = self._agenda_coverage.get_metric(reset)
        return metrics
Example No. 24
0
class SpiderParser(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 question_embedder: TextFieldEmbedder,
                 input_attention: Attention,
                 past_attention: Attention,
                 max_decoding_steps: int,
                 action_embedding_dim: int,
                 gnn: bool = True,
                 decoder_use_graph_entities: bool = True,
                 decoder_self_attend: bool = True,
                 gnn_timesteps: int = 2,
                 parse_sql_on_decoding: bool = True,
                 add_action_bias: bool = True,
                 use_neighbor_similarity_for_linking: bool = True,
                 dataset_path: str = 'dataset',
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels',
                 scoring_dev_params: dict = None,
                 debug_parsing: bool = False) -> None:
        super().__init__(vocab)
        self.vocab = vocab
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._question_embedder = question_embedder
        self._add_action_bias = add_action_bias
        self._scoring_dev_params = scoring_dev_params or {}
        self.parse_sql_on_decoding = parse_sql_on_decoding
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        self._self_attend = decoder_self_attend
        self._decoder_use_graph_entities = decoder_use_graph_entities

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._acc_single = Average()
        self._acc_multi = Average()
        self._beam_hit = Average()

        self._action_embedding_dim = action_embedding_dim

        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        encoder_output_dim = encoder.get_output_dim()
        if gnn:
            encoder_output_dim += action_embedding_dim

        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder_output_dim))
        self._first_attended_output = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        torch.nn.init.normal_(self._first_attended_output)

        self._num_entity_types = 9
        self._embedding_dim = question_embedder.get_output_dim()

        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._linking_params = torch.nn.Linear(16, 1)
        torch.nn.init.uniform_(self._linking_params.weight, 0, 1)

        num_edge_types = 3
        self._gnn = GatedGraphConv(self._embedding_dim,
                                   gnn_timesteps,
                                   num_edge_types=num_edge_types,
                                   dropout=dropout)

        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        if decoder_self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                past_attention=past_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)

        self._ent2ent_ff = FeedForward(action_embedding_dim, 1,
                                       action_embedding_dim,
                                       Activation.by_name('relu')())

        self._neighbor_params = torch.nn.Linear(self._embedding_dim,
                                                self._embedding_dim)

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(
            evaluate,
            db_dir=os.path.join(dataset_path, 'database'),
            table=os.path.join(dataset_path, 'tables.json'),
            check_valid=False)

        self.debug_parsing = debug_parsing

    @overrides
    def forward(
            self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            world: List[SpiderWorld],
            schema: Dict[str, torch.LongTensor],
            action_sequence: torch.LongTensor = None
    ) -> Dict[str, torch.Tensor]:

        batch_size = len(world)
        device = utterance['tokens'].device

        initial_state = self._get_initial_state(utterance, world, schema,
                                                valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            action_mask = action_sequence != self._action_padding_index
        else:
            action_mask = None

        if self.training:
            decode_output = self._decoder_trainer.decode(
                initial_state, self._transition_function,
                (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))

            return {'loss': decode_output['loss']}
        else:
            loss = torch.tensor([0]).float().to(device)
            if action_sequence is not None and action_sequence.size(1) > 1:
                try:
                    loss = self._decoder_trainer.decode(
                        initial_state, self._transition_function,
                        (action_sequence.unsqueeze(1),
                         action_mask.unsqueeze(1)))['loss']
                except ZeroDivisionError:
                    # reached a dead-end during beam search
                    pass

            outputs: Dict[str, Any] = {'loss': loss}

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]

            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False)

            self._compute_validation_outputs(valid_actions, best_final_states,
                                             world, action_sequence, outputs)
            return outputs

    def _get_initial_state(
            self, utterance: Dict[str, torch.LongTensor],
            worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor],
            actions: List[List[ProductionRule]]) -> GrammarBasedState:
        schema_text = schema['text']
        """KAIMARY"""
        # TextFieldEmbedder needs a "token" key in the Dict
        """
        embedded_schema:torch.Size([batch_size, num_entities, max_num_entity_tokens, embedding_dim])
        schema_mask:torch.Size([batch_size, num_entities, max_num_entity_tokens])
        embedded_utterance:torch.Size([batch_size, max_utterance_size, embedding_dim])
        entity_type_embeddings:torch.Size([batch_size, num_entities, embedding_dim])
        """
        embedded_schema = self._question_embedder(schema_text,
                                                  num_wrapping_dims=1)
        schema_mask = util.get_text_field_mask(schema_text,
                                               num_wrapping_dims=1).float()

        embedded_utterance = self._question_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
        num_entities = max([
            len(world.db_context.knowledge_graph.entities) for world in worlds
        ])
        num_question_tokens = utterance['tokens'].size(1)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            worlds, num_entities, embedded_schema.device)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_schema.view(batch_size, num_entities * num_entity_tokens,
                                 self._embedding_dim),
            torch.transpose(embedded_utterance, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)
        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)
        """KAIMARY"""
        # Variable: linking_scores
        # The entitiy linking score s(e, i) in the Krishnamurthy 2017
        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = schema['linking']

        linking_scores = question_entity_similarity_max_score

        feature_scores = self._linking_params(linking_features).squeeze(3)

        linking_scores = linking_scores + feature_scores
        """KAIMARY"""
        # linking_probabilities
        # The scores s(e,i) are then fed into a softmax layer over all entities e of the same type
        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            worlds, linking_scores.transpose(1, 2), utterance_mask,
            entity_type_dict)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(worlds, num_entities,
                                                      linking_scores.device)

        if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
            """KAIMARY"""
            # Seq2VecEncoder get the hidden state of the last step as the unique output
            # (batch_size, num_entities, embedding_dim)
            encoded_table = self._entity_encoder(embedded_schema, schema_mask)

            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(
                encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask(
                {
                    'ignored': neighbor_indices + 1
                }, num_wrapping_dims=1).float()

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(
                BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                                  neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(
                embedded_neighbors.float())
            """KAIMARY"""
            # Variable: entity_embedding
            # Rv in B Bogin 2019
            # Is a learned embedding for the schema item v, which base the embedding on the type of v and its schema neighbors only
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings +
                                           projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)
        """KAIMARY"""
        # Variable: link_embedding
        # Li in B Bogin 2019
        # Is an average of entity vectors weighted by the resulting distribution
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        """KAIMARY"""
        # Variable: encoder_input
        # [Wi, Li] in B Bogin 2019
        encoder_input = torch.cat([link_embedding, embedded_utterance], 2)

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))
        """KAIMARY"""
        # Variable: max_entities_relevance
        # ρv = maxi plink(v | xi) in B Bogin 2019
        # Is the maximum probability of v for any word xi
        max_entities_relevance = linking_probabilities.max(dim=1)[0]
        entities_relevance = max_entities_relevance.unsqueeze(-1).detach()
        """KAIMARY"""
        # entity_type_embeddings ???
        # Variable: graph_initial_embedding
        # hv(0) in B Bogin 2019
        # Is an initial embedding conditioned on the relevance score, and then used to be fed into GNN
        graph_initial_embedding = entity_type_embeddings * entities_relevance

        encoder_output_dim = self._encoder.get_output_dim()
        if self._gnn:
            """KAIMARY"""
            # Variable: entities_graph_encoding
            # φv in  B Bogin 2019
            # Is the final representation of each schema item after L steps
            entities_graph_encoding = self._get_schema_graph_encoding(
                worlds, graph_initial_embedding)
            """KAIMARY"""
            # Variable: graph_link_embedding
            # Lφ,i in B Bogin 2019
            graph_link_embedding = util.weighted_sum(entities_graph_encoding,
                                                     linking_probabilities)
            encoder_outputs = torch.cat(
                (encoder_outputs, graph_link_embedding), dim=-1)
            encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
        else:
            entities_graph_encoding = None

        if self._self_attend:
            # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
            entities_ff = self._ent2ent_ff(entities_graph_encoding)
            linked_actions_linking_scores = torch.bmm(
                entities_ff, entities_ff.transpose(1, 2))
        else:
            linked_actions_linking_scores = [None] * batch_size

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(
                worlds[i], actions[i], linking_scores[i],
                linked_actions_linking_scores[i], entity_types[i],
                entities_graph_encoding[i]
                if entities_graph_encoding is not None else None)
            for i in range(batch_size)
        ]

        initial_sql_state = [
            SqlState(actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=actions,
            action_entity_mapping=[
                w.get_action_entity_mapping() for w in worlds
            ])

        return initial_state
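
    # Recap of the encoder inputs built above, following Bogin et al. (2019): each
    # utterance token x_i is paired with a link embedding l_i = Σ_v p_link(v | x_i) · r_v,
    # and the encoder consumes the concatenation of l_i and the word embedding w_i; when
    # the GNN is enabled, a graph-based link embedding l_{φ,i} is further concatenated to
    # the encoder outputs.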

    @staticmethod
    def _get_neighbor_indices(worlds: List[SpiderWorld], num_entities: int,
                              device: torch.device) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors. A tensor
        is accepted as a parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index. If all the entities in the batch
        have no neighbors, None will be returned.
        """

        num_neighbors = 0
        for world in worlds:
            for entity in world.db_context.knowledge_graph.entities:
                if len(world.db_context.knowledge_graph.neighbors[entity]
                       ) > num_neighbors:
                    num_neighbors = len(
                        world.db_context.knowledge_graph.neighbors[entity])

        batch_neighbors = []
        no_entities_have_neighbors = True
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.db_context.knowledge_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.db_context.knowledge_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [
                    entity2index[n] for n in entity2neighbors[entity]
                ]
                if entity_neighbors:
                    no_entities_have_neighbors = False
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors,
                                                num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(
                neighbor_indexes, num_entities, lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        # It is possible that none of the entities has any neighbors, since our definition of the
        # knowledge graph allows it when no entities or numbers were extracted from the question.
        if no_entities_have_neighbors:
            return None
        return torch.tensor(batch_neighbors, device=device, dtype=torch.long)
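
    # Worked example (hypothetical schema, not from the original listing): with entities
    # ['table:flights', 'column:flights:id'] where each is the other's only neighbor and
    # num_neighbors = 2, the per-instance rows are
    #     [[1, -1],    # neighbors of table:flights, padded with -1
    #      [0, -1]]    # neighbors of column:flights:id
    # -1 (not 0) marks padding because 0 is a valid neighbor index.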

    def _get_schema_graph_encoding(
        self, worlds: List[SpiderWorld], initial_graph_embeddings: torch.Tensor
    ) -> torch.Tensor:
        max_num_entities = max([
            len(world.db_context.knowledge_graph.entities) for world in worlds
        ])
        batch_size = initial_graph_embeddings.size(0)

        graph_data_list = []

        for batch_index, world in enumerate(worlds):
            x = initial_graph_embeddings[batch_index]

            adj_list = self._get_graph_adj_lists(
                initial_graph_embeddings.device, world,
                initial_graph_embeddings.size(1) - 1)
            graph_data = Data(x)
            for i, l in enumerate(adj_list):
                graph_data[f'edge_index_{i}'] = l
            graph_data_list.append(graph_data)

        batch = Batch.from_data_list(graph_data_list)

        gnn_output = self._gnn(batch.x, [
            batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types)
        ])

        num_nodes = max_num_entities
        gnn_output = gnn_output.view(batch_size, num_nodes, -1)
        # With no global node this slice is a no-op; with a global node, its encoding
        # would be gnn_output[:, max_num_entities].
        entities_encodings = gnn_output[:, :max_num_entities]

        return entities_encodings

    @staticmethod
    def _get_graph_adj_lists(device,
                             world,
                             global_entity_id,
                             global_node=False):
        entity_mapping = {}
        for i, entity in enumerate(world.db_context.knowledge_graph.entities):
            entity_mapping[entity] = i
        entity_mapping['_global_'] = global_entity_id
        adj_list_own = []  # column--table
        adj_list_link = []  # table->table / foreign->primary
        adj_list_linked = []  # table<-table / foreign<-primary
        adj_list_global = []  # node->global

        # TODO: Prepare in advance?
        for key, neighbors in world.db_context.knowledge_graph.neighbors.items():
            idx_source = entity_mapping[key]
            for n_key in neighbors:
                idx_target = entity_mapping[n_key]
                if n_key.startswith("table") or key.startswith("table"):
                    adj_list_own.append((idx_source, idx_target))
                elif n_key.startswith("string") or key.startswith("string"):
                    adj_list_own.append((idx_source, idx_target))
                elif key.startswith("column:foreign"):
                    adj_list_link.append((idx_source, idx_target))
                    src_table_key = f"table:{key.split(':')[2]}"
                    tgt_table_key = f"table:{n_key.split(':')[2]}"
                    idx_source_table = entity_mapping[src_table_key]
                    idx_target_table = entity_mapping[tgt_table_key]
                    adj_list_link.append((idx_source_table, idx_target_table))
                elif n_key.startswith("column:foreign"):
                    adj_list_linked.append((idx_source, idx_target))
                    src_table_key = f"table:{key.split(':')[2]}"
                    tgt_table_key = f"table:{n_key.split(':')[2]}"
                    idx_source_table = entity_mapping[src_table_key]
                    idx_target_table = entity_mapping[tgt_table_key]
                    adj_list_linked.append(
                        (idx_source_table, idx_target_table))
                else:
                    assert False

            adj_list_global.append((idx_source, entity_mapping['_global_']))

        all_adj_types = [adj_list_own, adj_list_link, adj_list_linked]

        if global_node:
            all_adj_types.append(adj_list_global)

        return [
            torch.tensor(l, device=device, dtype=torch.long).transpose(0, 1)
            if l else torch.tensor(l, device=device, dtype=torch.long)
            for l in all_adj_types
        ]
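
    # Illustrative output (hypothetical schema, not from the original listing): for a
    # column--table edge between entities 0 and 1, an adjacency list [(0, 1)] becomes a
    # tensor of shape (2, 1) after the transpose, i.e. tensor([[0], [1]]), matching the
    # (2, num_edges) edge_index layout used by PyTorch Geometric-style graph convolutions;
    # empty lists stay as empty tensors.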

    def _create_grammar_state(
            self, world: SpiderWorld, possible_actions: List[ProductionRule],
            linking_scores: torch.Tensor,
            linked_actions_linking_scores: torch.Tensor,
            entity_types: torch.Tensor,
            entity_graph_encoding: torch.Tensor) -> GrammarStatelet:
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = world.valid_actions
        entity_map = {}
        entities = world.entities_names

        for entity_index, entity in enumerate(entities):
            entity_map[entity] = entity_index

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(
                    global_action_tensors,
                    dim=0).to(global_action_tensors[0].device).long()
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [
                    rule.split(' -> ')[1].strip('[]\"')
                    for rule in linked_rules
                ]

                entity_ids = [entity_map[entity] for entity in entities]

                entity_linking_scores = linking_scores[entity_ids]

                entity_action_linking_scores = None
                if linked_actions_linking_scores is not None:
                    entity_action_linking_scores = linked_actions_linking_scores[
                        entity_ids]

                if not self._decoder_use_graph_entities:
                    entity_type_tensor = entity_types[entity_ids]
                    entity_type_embeddings = (
                        self._entity_type_decoder_embedding(
                            entity_type_tensor).to(
                                entity_types.device).float())
                else:
                    entity_type_embeddings = entity_graph_encoding.index_select(
                        dim=0,
                        index=torch.tensor(
                            entity_ids, device=entity_graph_encoding.device))

                if self._self_attend:
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_embeddings,
                        list(linked_action_ids), entity_action_linking_scores)
                else:
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_embeddings,
                        list(linked_action_ids))

        return GrammarStatelet(['statement'], translated_valid_actions,
                               self.is_nonterminal)

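    # A small worked example (hypothetical production rule) of how linked-action
    # entity names are recovered in `_create_grammar_state` above: the entity
    # name is the right-hand side of the rule, stripped of brackets and quotes.
    #
    #   >>> rule = 'table_name -> ["singer"]'
    #   >>> rule.split(' -> ')[1].strip('[]"')
    #   'singer'
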
    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    def _get_linking_probabilities(
            self, worlds: List[SpiderWorld], linking_scores: torch.FloatTensor,
            question_mask: torch.LongTensor,
            entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name ("column:..." entries
            # come before "string:", which comes before "table:").
            # This is not a great assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.db_context.knowledge_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities +
                                        entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices,
                                                    dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(
                    1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores,
                                                                 dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(
                    num_question_tokens,
                    num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

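    # A minimal sketch of the per-type normalization above (hypothetical scores
    # for one instance and one entity type): a null entity with score 0 is
    # prepended, softmax is taken within the type, and the null column is then
    # dropped, so the remaining probabilities sum to less than 1.
    #
    #   >>> import torch
    #   >>> scores = torch.tensor([[2.0, 1.0]])  # (num_question_tokens, num_entities_of_type)
    #   >>> with_null = torch.cat([torch.zeros(1, 1), scores], dim=1)
    #   >>> torch.nn.functional.softmax(with_null, dim=1)[:, 1:]
    #   tensor([[0.6652, 0.2447]])
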
    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(
            torch.min(targets_trimmed.eq(predicted_tensor), dim=0)[0]).item()

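    # A small worked example (hypothetical tensors) of the prefix match above:
    # the prediction counts as a match only if it equals the first
    # len(predicted) target actions.
    #
    #   >>> import torch
    #   >>> targets = torch.tensor([3, 5, 7, 0])
    #   >>> predicted = [3, 5, 7]
    #   >>> trimmed = targets[:len(predicted)]
    #   >>> int(torch.max(torch.min(trimmed.eq(torch.tensor(predicted)), dim=0)[0]).item())
    #   1
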
    @staticmethod
    def _query_difficulty(targets: torch.LongTensor, action_mapping,
                          batch_index):
        number_tables = len([
            action_mapping[(batch_index, int(a))] for a in targets
            if a >= 0 and action_mapping[(batch_index,
                                          int(a))].startswith('table_name')
        ])
        return number_tables > 1

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            '_match/exact_match': self._exact_match.get_metric(reset),
            'sql_match': self._sql_evaluator_match.get_metric(reset),
            '_others/action_similarity':
            self._action_similarity.get_metric(reset),
            '_match/match_single': self._acc_single.get_metric(reset),
            '_match/match_hard': self._acc_multi.get_metric(reset),
            'beam_hit': self._beam_hit.get_metric(reset)
        }

    @staticmethod
    def _get_type_vector(worlds: List[SpiderWorld], num_entities: int,
                         device) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the encoding for each entity's type. In addition, a map from a flattened entity
        index to type is returned to combine entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        num_entities : ``int``
        device
            Device onto which the constructed tensor is placed.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)`` holding a type id per entity.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []

        column_type_ids = [
            'boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time'
        ]

        for batch_index, world in enumerate(worlds):
            types = []

            for entity_index, entity in enumerate(
                    world.db_context.knowledge_graph.entities):
                parts = entity.split(':')
                entity_main_type = parts[0]
                if entity_main_type == 'column':
                    column_type = parts[1]
                    entity_type = column_type_ids.index(column_type)
                elif entity_main_type == 'string':
                    # cell value
                    entity_type = len(column_type_ids)
                elif entity_main_type == 'table':
                    entity_type = len(column_type_ids) + 1
                else:
                    raise (Exception("Unkown entity"))
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)

        return torch.tensor(batch_types, dtype=torch.long,
                            device=device), entity_types

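    # A small worked example of the flattened key used above (hypothetical
    # sizes): with a padded maximum of num_entities = 10, entity 3 of batch
    # instance 2 is stored under key 2 * 10 + 3 = 23, matching the layout of
    # the linking scores.
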
    def _compute_validation_outputs(self,
                                    actions: List[List[ProductionRuleArray]],
                                    best_final_states: Mapping[
                                        int, Sequence[GrammarBasedState]],
                                    world: List[SpiderWorld],
                                    target_list: List[List[str]],
                                    outputs: Dict[str, Any]) -> None:
        batch_size = len(actions)

        outputs['predicted_sql_query'] = []

        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]

        for i in range(batch_size):
            # gold sql exactly as given
            original_gold_sql_query = ' '.join(
                world[i].get_query_without_table_hints())

            if i not in best_final_states:
                self._exact_match(0)
                self._action_similarity(0)
                self._sql_evaluator_match(0)
                self._acc_multi(0)
                self._acc_single(0)
                outputs['predicted_sql_query'].append('')
                continue

            best_action_indices = best_final_states[i][0].action_history[0]

            action_strings = [
                action_mapping[(i, action_index)]
                for action_index in best_action_indices
            ]
            predicted_sql_query = action_sequence_to_sql(action_strings,
                                                         add_table_names=True)
            outputs['predicted_sql_query'].append(
                sqlparse.format(predicted_sql_query, reindent=False))

            if target_list is not None:
                targets = target_list[i].data

                sequence_in_targets = self._action_history_match(
                    best_action_indices, targets)
                self._exact_match(sequence_in_targets)

                sql_evaluator_match = self._evaluate_func(
                    original_gold_sql_query, predicted_sql_query,
                    world[i].db_id)
                self._sql_evaluator_match(sql_evaluator_match)

                similarity = difflib.SequenceMatcher(None, best_action_indices,
                                                     targets)
                self._action_similarity(similarity.ratio())

                difficulty = self._query_difficulty(targets, action_mapping, i)
                if difficulty:
                    self._acc_multi(sql_evaluator_match)
                else:
                    self._acc_single(sql_evaluator_match)

            beam_hit = False
            for pos, final_state in enumerate(best_final_states[i]):
                action_indices = final_state.action_history[0]
                action_strings = [
                    action_mapping[(i, action_index)]
                    for action_index in action_indices
                ]
                candidate_sql_query = action_sequence_to_sql(
                    action_strings, add_table_names=True)

                if target_list is not None:
                    correct = self._evaluate_func(original_gold_sql_query,
                                                  candidate_sql_query,
                                                  world[i].db_id)
                    if correct:
                        beam_hit = True
                    self._beam_hit(beam_hit)
Ejemplo n.º 25
0
class BaseRollinRolloutDecoder(SeqDecoder):
    """
    A base decoder with a rollin-rollout formulation, used to define other decoders such as
    the autoregressive decoder, REINFORCE decoder, SEARNN decoder, etc.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    decoder_net : ``DecoderNet``, required
        Module that contains implementation of neural network for decoding output elements
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_embedder : ``Embedding``
        Embedder for target tokens.
    target_namespace : ``str``, optional (default = 'target_tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    beam_size : ``int``, optional (default = 4)
        Width of the beam for beam search.
    tensor_based_metric : ``Metric``, optional (default = None)
        A metric to track on validation data that takes raw tensors when it is called.
        This metric must accept two arguments when called: a batched tensor
        of predicted token indices, and a batched tensor of gold token indices.
    token_based_metric : ``Metric``, optional (default = None)
        A metric to track on validation data that takes lists of lists of tokens
        as input. This metric must accept two arguments when called, both
        of type `List[List[str]]`. The first is a predicted sequence for each item
        in the batch and the second is a gold sequence for each item in the batch.
    scheduled_sampling_ratio : ``float``, optional (default = 0)
        Defines the ratio between teacher-forced training and use of the model's own outputs. If it is zero
        (teacher forcing only) and ``decoder_net`` supports parallel decoding, we get the output
        predictions in a single forward pass of the ``decoder_net``.
    """

    default_implementation = "auto_regressive_seq_decoder"
    
    def __init__(self,
                 vocab: Vocabulary,
                 max_decoding_steps: int,
                 decoder_net: DecoderNet,
                 target_embedder: Embedding,
                 loss_criterion: LossCriterion,
                
                 generation_batch_size: int = 200,
                 use_in_seq2seq_mode: bool = False,
                 target_namespace: str = "tokens",
                 beam_size: int = None,
                 scheduled_sampling_ratio: float = 0.0,
                 scheduled_sampling_k: int = 100,
                 scheduled_sampling_type: str = 'uniform',
                 rollin_mode: str = 'mixed',
                 rollout_mode: str = 'learned',

                 dropout: float = None,
                 start_token: str = START_SYMBOL,
                 end_token: str = END_SYMBOL,
                 num_decoder_layers: int = 1,
                 mask_pad_and_oov: bool = False,
                 tie_output_embedding: bool = False,

                 rollout_mixing_prob:float = 0.5,

                 use_bleu: bool = False,
                 use_hamming: bool = False,

                 sample_rollouts: bool = False,
                 beam_search_sampling_temperature: float = 1.,
                 top_k=0, 
                 top_p=0,
                 tensor_based_metric: Metric = None,
                 tensor_based_metric_mask: Metric = None,
                 token_based_metric: Metric = None,
                 eval_beam_size: int = 1,
                ) -> None:
        super().__init__(target_embedder)

        self.current_device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        self._vocab = vocab
        self._seq2seq_mode = use_in_seq2seq_mode

        # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
        self._max_decoding_steps = max_decoding_steps
        self._generation_batch_size = generation_batch_size
        self._decoder_net = decoder_net

        self._target_namespace = target_namespace

        # TODO #4 (Kushal): Maybe make them modules so that we can add more of these later.
        # TODO #8 #7 (Kushal): Rename "mixed" rollin mode to "scheduled sampling".
        self._rollin_mode = rollin_mode
        self._rollout_mode = rollout_mode

        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._scheduled_sampling_k = scheduled_sampling_k
        self._scheduled_sampling_type = scheduled_sampling_type
        self._sample_rollouts = sample_rollouts
        self._mask_pad_and_oov = mask_pad_and_oov

        self._rollout_mixing_prob = rollout_mixing_prob

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self._vocab.get_token_index(start_token, self._target_namespace)
        self._end_index = self._vocab.get_token_index(end_token, self._target_namespace)

        self._padding_index = self._vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._target_namespace)
        self._oov_index = self._vocab.get_token_index(DEFAULT_OOV_TOKEN, self._target_namespace)

        if self._mask_pad_and_oov:
            self._vocab_mask = torch.ones(self._vocab.get_vocab_size(self._target_namespace),
                                            device=self.current_device) \
                                    .scatter(0, torch.tensor([self._padding_index, self._oov_index, self._start_index],
                                                                device=self.current_device),
                                                0)
        if use_bleu:
            pad_index = self._vocab.get_token_index(self._vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        if use_hamming:
            self._hamming = HammingLoss()
        else:
            self._hamming = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1

        # TODO(Kushal): Pass in the arguments for sampled. Also, make sure you do not sample in case of Seq2Seq models.
        self._beam_search = SampledBeamSearch(self._end_index, 
                                                max_steps=max_decoding_steps, 
                                                beam_size=beam_size, temperature=beam_search_sampling_temperature)

        self._num_classes = self._vocab.get_vocab_size(self._target_namespace)

        if self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input." + 
                    f"target_embedder_dim: {self.target_embedder.get_output_dim()}, " + 
                    f"decoder input dim: {self._decoder_net.target_embedding_dim}."
            )

        self._ss_ratio = Average()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        self.training_iteration = 0
        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_net.get_output_dim(), self._num_classes)

        if tie_output_embedding:
            if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape:
                raise ConfigurationError(
                    f"Can't tie embeddings with output linear layer, due to shape mismatch. " + 
                    f"{self._output_projection_layer.weight.shape} and {self.target_embedder.weight.shape}"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        self._loss_criterion = loss_criterion

        self._top_k = top_k
        self._top_p = top_p
        self._eval_beam_size = eval_beam_size
        self._mle_loss = MaximumLikelihoodLossCriterion()
        self._perplexity = Perplexity()

        # These metrics will be updated during training and validation
        self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric
        self._tensor_based_metric_mask = tensor_based_metric_mask
        
        self._decode_tokens = partial(decode_tokens, 
                                    vocab=self._vocab,
                                    start_index=self._start_index,
                                    end_index=self._end_index)
                                    
    def get_output_dim(self):
        return self._decoder_net.get_output_dim()

    def rollin_policy(self,
                      timestep: int,
                      last_predictions: torch.LongTensor,
                      target_tokens: Dict[str, torch.Tensor] = None,
                      rollin_mode = None) -> torch.LongTensor:
        """ Roll-in policy to use.
            This takes in targets, timestep and last_predictions, and decide
            which to use for taking next step i.e., generating next token.
            What to do is decided by rolling mode. Options are
                - teacher_forcing,
                - learned,
                - mixed,

            By default the mode is mixed with scheduled_sampling_ratio=0.0. This 
            defaults to teacher_forcing. You can also explicitly run with teacher_forcing
            mode.

        Arguments:
            timestep {int} -- Current timestep decides which target token to use.
                              In case of teacher_forcing this is usually {t-1}^{th} timestep
                              for predicting t^{th} token.
            last_predictions {torch.LongTensor} -- {t-1}^th token predicted by the model.

        Keyword Arguments:
            target_tokens {Dict[str, torch.Tensor]} -- Target tokens, if available. These will be
                                           available in training mode but not in inference mode. (default: {None})
            rollin_mode {str} -- Rollin mode. Options are
                                  teacher_forcing, learned, mixed (default: {self._rollin_mode})
        Returns:
            torch.LongTensor -- The method returns input token for predicting next token.
        """
        rollin_mode = rollin_mode or self._rollin_mode

        # For first timestep, you are passing start token, so don't do anything smart.
        if (timestep == 0 or
           # If no targets, no way to do teacher_forcing, so use your own predictions.
           target_tokens is None  or
           rollin_mode == 'learned'):
            # shape: (batch_size,)
            return last_predictions

        targets = util.get_token_ids_from_text_field_tensors(target_tokens)
        if rollin_mode == 'teacher_forcing':
            # shape: (batch_size,)
            input_choices = targets[:, timestep]
        elif rollin_mode == 'mixed':
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # During training, use the model's own predictions with probability
                # self._scheduled_sampling_ratio (gold tokens are used otherwise,
                # and always at test time).
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]
        else:
            raise ConfigurationError(f"invalid configuration for rollin policy: {rollin_mode}")
        return input_choices

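    # A minimal standalone sketch of the 'mixed' branch above (hypothetical
    # ratio): with scheduled_sampling_ratio = 0.25, roughly one training step
    # in four feeds the model its own previous prediction instead of the gold
    # token.
    #
    #   >>> import torch
    #   >>> ratio = 0.25
    #   >>> use_model_prediction = torch.rand(1).item() < ratio  # random draw per step
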
    def copy_reference_policy(self,
                                timestep,
                                last_predictions: torch.LongTensor,
                                state: Dict[str, torch.Tensor],
                                target_tokens: Dict[str, torch.LongTensor],
                              ) -> Tuple[torch.FloatTensor, Dict[str, torch.Tensor]]:
        targets = util.get_token_ids_from_text_field_tensors(target_tokens)
        seq_len = targets.size(1)
        
        batch_size = last_predictions.shape[0]
        if seq_len > timestep + 1:  # + 1 because timestep is an index, indexed at 0.
            # Since we are overriding the next (predicted) token, we
            # use the target value at the {t+1}^{th} timestep.
            target_at_timesteps = targets[:, timestep + 1]
        else:
            # We have overshot the seq_len, so just repeat the
            # last token which is either _end_token or _pad_token.
            target_at_timesteps = targets[:, -1]

        # TODO: Add support to allow other types of reference policies.
        # target_logits: (batch_size, num_classes).
        # This tensor has 0 at targets and (near) -inf at other places.
        target_logits = (target_at_timesteps.new_zeros((batch_size, self._num_classes)) + 1e-45) \
                            .scatter_(dim=1,
                                      index=target_at_timesteps.unsqueeze(1),
                                      value=1.0).log()
        return target_logits, state
    
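    # A minimal sketch of the log-space one-hot built above (hypothetical
    # target index 2, vocabulary of size 4): scattering 1.0 onto a tiny
    # positive baseline and taking the log gives 0 at the target and a large
    # negative value (about -103 in float32) everywhere else.
    #
    #   >>> import torch
    #   >>> target = torch.tensor([2])
    #   >>> (torch.zeros(1, 4) + 1e-45).scatter_(1, target.unsqueeze(1), 1.0).log()
    #   # -> ~0 at index 2, ~-103 elsewhere
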
    def oracle_reference_policy(self, 
                                timestep: int,
                                last_predictions: torch.LongTensor,
                                state: Dict[str, torch.Tensor],
                                token_to_idx: Dict[str, int],
                                idx_to_token: Dict[int, str],
                               ) -> Tuple[torch.FloatTensor, Dict[str, torch.Tensor]]:
        # TODO(Kushal): #5 This is a temporary fix. Ideally, we should have
        # an individual oracle for this which is different from cost function.
        assert hasattr(self._loss_criterion, "_rollout_cost_function") and \
                     hasattr(self._loss_criterion._rollout_cost_function, "_oracle"), \
                "For oracle reference policy, we will need noisy oracle loss function"

        start_time = time.time()
        target_logits, state = self._loss_criterion \
                                    ._rollout_cost_function \
                                    ._oracle \
                                    .reference_step_rollout(
                                        step=timestep,
                                        last_predictions=last_predictions,
                                        state=state,
                                        token_to_idx=token_to_idx,
                                        idx_to_token=idx_to_token)
        end_time = time.time()
        logger.info(f"Oracle Reference time: {end_time - start_time} s")
        return target_logits, state
    
    def rollout_policy(self,
                       timestep: int,
                       last_predictions: torch.LongTensor, 
                       state: Dict[str, torch.Tensor],
                       logits: torch.FloatTensor,
                       reference_policy:ReferencePolicyType,
                       rollout_mode: str = None,
                       rollout_mixing_func: RolloutMixingProbFuncType = None,
                      ) -> torch.FloatTensor:
        """Rollout policy to use.
           This takes in predicted logits at timestep {t}^{th} and
           depending upon the rollout_mode replaces some of the predictions
           with targets.

           The options for rollout mode are:
               - learned,
               - reference,
               - mixed.

        Arguments:
            timestep {int} -- Current timestep decides which target token to use.
                              In case of reference this is usually {t-1}^{th} timestep
                              for predicting t^{th} token.
            logits {torch.LongTensor} -- Logits generated by the model for {t}^{th} timestep.
                                         (batch_size, num_classes).

        Keyword Arguments:
            reference_policy {ReferencePolicyType} -- Policy that produces the target logits
                                    used by the reference and mixed rollout modes.
            rollout_mode {str} -- Rollout mode. Options are:
                                    learned, reference, mixed. (default: {self._rollout_mode})
            rollout_mixing_func {RolloutMixingProbFuncType} -- Function producing the mask that chooses
                                    between predicted logits and targets in mixed rollouts.
                                    (default: Bernoulli draws with p=self._rollout_mixing_prob)

        Returns:
            torch.LongTensor -- The method returns logits with rollout policy applied.
        """
        rollout_mode = rollout_mode or self._rollout_mode
        output_logits = logits

        if rollout_mode == 'learned':
            # For learned rollout policy, just return the same logits.
            return output_logits, state

        target_logits, state = reference_policy(timestep, 
                                                last_predictions,
                                                state)

        batch_size = logits.size(0)
        if rollout_mode == 'reference':
            output_logits += target_logits
        elif rollout_mode == 'mixed':
            # Based on the mask (Value=1), copy target values.

            if rollout_mixing_func is not None:
                rollout_mixing_prob_tensor = rollout_mixing_func()
            else:
                # This returns a (batch_size,) tensor of 0s and 1s; after the expand below,
                # each row of the (batch_size, num_classes) mask is all zeros or all ones.
                rollout_mixing_prob_tensor = torch.bernoulli(torch.ones(batch_size) * self._rollout_mixing_prob)

            rollout_mixing_mask = rollout_mixing_prob_tensor \
                                    .unsqueeze(1) \
                                    .expand(logits.shape) \
                                    .to(self.current_device)

            # The target_logits ranges from (-inf , 0), so, by adding those to logits,
            # we turn the values that are not target tokens to -inf, hence making the distribution
            # skew towards the target.
            output_logits += rollout_mixing_mask * target_logits
        else:
            raise ConfigurationError(f"Incompatible rollout mode: {rollout_mode}")
        return output_logits, state

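    # A minimal sketch of the mixing mask above (hypothetical batch of 4
    # instances): each row of the (batch_size, num_classes) mask is all ones or
    # all zeros, so a whole instance either keeps its learned logits or is
    # skewed towards the reference distribution.
    #
    #   >>> import torch
    #   >>> rows = torch.bernoulli(torch.ones(4) * 0.5)  # e.g. tensor([1., 0., 1., 1.])
    #   >>> mask = rows.unsqueeze(1).expand(4, 10)       # (batch_size, num_classes)
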
    def take_step(self,
                  timestep: int,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor],
                  rollin_policy:RollinPolicyType=default_rollin_policy,
                  rollout_policy:RolloutPolicyType=default_rollout_policy,
                 ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        input_choices = rollin_policy(timestep, last_predictions)

        # Store the timestep in the state; we might need it in _prepare_output_projections.
        state['timestep'] = timestep
        # shape: (group_size, num_classes)
        class_logits, state = self._prepare_output_projections(
                                                last_predictions=input_choices,
                                                state=state)

        if not self.training and self._mask_pad_and_oov:
            # This implementation is copied from masked_log_softmax from allennlp.nn.util.
            mask = (self._vocab_mask.expand(class_logits.shape) + 1e-45).log()
            # shape: (group_size, num_classes)
            class_logits = class_logits + mask

        # shape: (group_size, num_classes)
        class_logits, state = rollout_policy(timestep, last_predictions, state, class_logits)
        class_logits = top_k_top_p_filtering(class_logits, self._top_k, self._top_p)
        return class_logits, state

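    # `top_k_top_p_filtering` above is assumed to follow the usual top-k /
    # nucleus filtering scheme. A minimal top-k sketch (hypothetical logits,
    # k=2): everything outside the k largest logits is set to -inf before
    # sampling.
    #
    #   >>> import torch
    #   >>> logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
    #   >>> kth = torch.topk(logits, k=2).values[..., -1, None]
    #   >>> logits.masked_fill(logits < kth, float('-inf'))
    #   tensor([[-inf, 3., 2., -inf]])
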
    @overrides
    def forward(self,  # type: ignore
                encoder_out: Dict[str, torch.LongTensor] = {},
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make a forward pass with the decoder logic to produce the entire target sequence.

        Parameters
        ----------
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        source_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        output_dict: Dict[str, torch.Tensor] = {}
        state: Dict[str, torch.Tensor] = {}
        decoder_init_state: Dict[str, torch.Tensor] = {}

        state.update(copy.copy(encoder_out))
        # In the Seq2Seq setting, we will encode the source sequence,
        # and init the state object with the encoder output; the decoder
        # cell will use these encoder outputs for attention and/or for
        # initializing the decoder states.
        if self._seq2seq_mode:
            decoder_init_state = \
                        self._decoder_net.init_decoder_state(state)
            state.update(decoder_init_state)

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        start_predictions: torch.LongTensor = \
                self._get_start_predictions(state,
                                        target_tokens,
                                        self._generation_batch_size)
        
        # If we have target_tokens, roll in and roll out
        # only for that many steps; otherwise roll out for
        # `self._max_decoding_steps`.
        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets: torch.LongTensor = \
                    util.get_token_ids_from_text_field_tensors(target_tokens)

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps: int = target_sequence_length - 1
        else:
            num_decoding_steps: int = self._max_decoding_steps

        if target_tokens:
            decoder_output_dict, rollin_dict, rollout_dict_iter = \
                                        self._forward_loop(
                                                state=state,
                                                start_predictions=start_predictions,
                                                num_decoding_steps=num_decoding_steps,
                                                target_tokens=target_tokens)

            output_dict.update(decoder_output_dict)
            predictions = decoder_output_dict['predictions']
            predicted_tokens = self._decode_tokens(predictions,
                                                    vocab_namespace=self._target_namespace,
                                                    truncate=True)
            output_dict["decoded_predictions"] = predicted_tokens

            decoded_targets = self._decode_tokens(targets,
                                    vocab_namespace=self._target_namespace,
                                    truncate=True)
            output_dict["decoded_targets"] = decoded_targets

            output_dict.update(self._loss_criterion(
                                            rollin_output_dict=rollin_dict, 
                                            rollout_output_dict_iter=rollout_dict_iter, 
                                            state=state, 
                                            target_tokens=target_tokens))

            mle_loss_output = self._mle_loss(
                                    rollin_output_dict=rollin_dict, 
                                    rollout_output_dict_iter=rollout_dict_iter, 
                                    state=state, 
                                    target_tokens=target_tokens)

            mle_loss = mle_loss_output['loss']
            self._perplexity(mle_loss)

        if not self.training:
            # While validating or testing we need to roll out the learned policy and the output
            # of this rollout is used to compute the secondary metrics
            # like BLEU.
            state: Dict[str, torch.Tensor] = {}
            state.update(copy.copy(encoder_out))
            state.update(decoder_init_state)

            rollout_output_dict = self.rollout(state,
                                        start_predictions,
                                        rollout_steps=num_decoding_steps,
                                        rollout_mode='learned',
                                        sampled=self._sample_rollouts,
                                        beam_size=self._eval_beam_size,
                                        # TODO #6 (Kushal): Add a reason why truncate_at_end_all is False here.
                                        truncate_at_end_all=False)

            output_dict.update(rollout_output_dict)

            # Decode the rollout predictions here; `decoder_output_dict` only
            # exists when target_tokens were provided.
            predictions = rollout_output_dict['predictions']
            predicted_tokens = self._decode_tokens(predictions,
                                                vocab_namespace=self._target_namespace,
                                                truncate=True)
            output_dict["decoded_predictions"] = predicted_tokens
            decoded_predictions = [beam_predictions[0] \
                                    for beam_predictions in output_dict["decoded_predictions"]]


            # shape (predictions): (batch_size, beam_size, num_decoding_steps)
            predictions = rollout_output_dict['predictions']

            # shape (best_predictions): (batch_size, num_decoding_steps)
            best_predictions = predictions[:, 0, :]

            if target_tokens:
                targets = util.get_token_ids_from_text_field_tensors(target_tokens)
                target_mask = util.get_text_field_mask(target_tokens)
                decoded_targets = self._decode_tokens(targets,
                                        vocab_namespace=self._target_namespace,
                                        truncate=True)

                # TODO #3 (Kushal): Maybe abstract out these losses and use loss_metric like AllenNLP uses.
                if self._bleu and target_tokens:
                    self._bleu(best_predictions, targets)

                if  self._hamming and target_tokens:
                    self._hamming(best_predictions, targets, target_mask)

                if self._tensor_based_metric is not None:
                    self._tensor_based_metric(  # type: ignore
                        predictions=best_predictions,
                        gold_targets=targets,
                    )
                if self._tensor_based_metric_mask is not None:
                    self._tensor_based_metric_mask(  # type: ignore
                        predictions=best_predictions,
                        gold_targets=targets,
                        mask=~target_mask,
                    )

                if self._token_based_metric is not None:
                    self._token_based_metric(  # type: ignore
                            predictions=decoded_predictions, 
                            gold_targets=decoded_targets,
                        )
        return output_dict

    @overrides
    def post_process(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        all_predicted_tokens = self._decode_tokens(predicted_indices, 
                                                    vocab_namespace=self._target_namespace,
                                                    truncate=True)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _apply_scheduled_sampling(self):
        if not self.training:
            raise RuntimeError("Scheduled Sampling can only be applied during training.")

        k = self._scheduled_sampling_k
        i = self.training_iteration
        if self._scheduled_sampling_type == 'uniform':
            # Keep the scheduled sampling ratio as set by the config.
            pass
        elif self._scheduled_sampling_type == 'linear':
            self._scheduled_sampling_ratio = i/float(k)
        elif self._scheduled_sampling_type == 'inverse_sigmoid':
            self._scheduled_sampling_ratio = 1 - k/(k + math.exp(i/k))
        else:
            raise ConfigurationError(f"{self._scheduled_sampling_type} is not a valid scheduled sampling type.")

        self._ss_ratio(self._scheduled_sampling_ratio)

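    # Worked values for the schedules above (hypothetical k = 100):
    #   linear:           i = 50  -> ratio = 50 / 100 = 0.5
    #   inverse_sigmoid:  i = 0   -> ratio = 1 - 100 / (100 + e^0)   ~= 0.0099
    #                     i = 460 -> ratio = 1 - 100 / (100 + e^4.6) ~= 0.499
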
    def rollin(self,
               state: Dict[str, torch.Tensor],
               start_predictions: torch.LongTensor,
               rollin_steps: int,
               target_tokens: Dict[str, torch.LongTensor] = None,
               beam_size: int = 1,
               per_node_beam_size: int = None,
               sampled: bool = False,
               truncate_at_end_all: bool = False,
               rollin_mode: str = None,
              ):
        self.training_iteration += 1

        # We cannot use an instance attribute as a default argument, so the
        # default is None and we fall back to the number of classes here.
        per_node_beam_size: int = per_node_beam_size or self._num_classes

        if self.training:
            self._apply_scheduled_sampling()

        rollin_policy = partial(self.rollin_policy,
                                target_tokens=target_tokens,
                                rollin_mode=rollin_mode)

        rolling_policy = partial(self.take_step,
                                 rollin_policy=rollin_policy)

        # shape (step_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        # shape (logits): (batch_size, beam_size, num_decoding_steps, num_classes)
        step_predictions, log_probabilities, logits = \
                    self._beam_search.search(start_predictions,
                                                state,
                                                rolling_policy,
                                                max_steps=rollin_steps,
                                                beam_size=beam_size,
                                                per_node_beam_size=per_node_beam_size,
                                                sampled=sampled,
                                                truncate_at_end_all=truncate_at_end_all)

        logits = torch.cat(logits, dim=2)

        batch_size, beam_size, _ = step_predictions.shape
        step_predictions = torch.cat([start_predictions.unsqueeze(1) \
                                        .expand(batch_size, beam_size) \
                                        .reshape(batch_size, beam_size, 1),
                                        step_predictions],
                                     dim=-1)

        output_dict = {
            "predictions": step_predictions,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }
        return output_dict

    def rollin_parallel(self, 
                        state: Dict[str, torch.Tensor],
                        start_predictions: torch.LongTensor,
                        rollin_steps: int,
                        target_tokens: Dict[str, torch.LongTensor] = None,
                        beam_size: int = 1,
                        per_node_beam_size: int = None,
                        sampled: bool = False,
                        truncate_at_end_all: bool = False,
                        rollin_mode: str = None,
                    ):
        assert self._decoder_net.decodes_parallel, \
            "Rollin Parallel is only applicable for transformer-style decoders " + \
            "that decode the whole sequence in parallel."

        assert not rollin_mode or rollin_mode == "learned", \
            "Parallel decoding only works when following the " + \
            "teacher forcing rollin policy (rollin_mode='learned')."

        assert self._scheduled_sampling_ratio == 0, \
            "For learned rollin mode, scheduled sampling ratio should always be 0."

        self.training_iteration += 1

        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = util.get_token_ids_from_text_field_tensors(target_tokens)

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self.target_embedder(targets)

        # shape: (batch_size, max_target_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        _, decoder_output = self._decoder_net(
            previous_state=state,
            previous_steps_predictions=target_embedding[:, :-1, :],
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_mask=target_mask[:, :-1],
        )

        # shape: (group_size, max_target_sequence_length, num_classes)
        logits = self._output_projection_layer(decoder_output)

        # Unsqueeze logits to add a beam size dimension.
        logits = logits.unsqueeze(dim=1)

        # Normalize before taking the max so that we return log probabilities,
        # not raw logits.
        log_probabilities, step_predictions = torch.max(F.log_softmax(logits, dim=-1), dim=-1)

        return {
            "predictions": step_predictions,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }

    def rollout(self,
                state: Dict[str, torch.Tensor],
                start_predictions: torch.LongTensor,
                rollout_steps: int,
                beam_size: int = None,
                per_node_beam_size: int = None,
                target_tokens: Dict[str, torch.LongTensor] = None,
                sampled: bool = True,
                truncate_at_end_all: bool = True,
                # shape (prediction_prefixes): (batch_size, prefix_length)
                prediction_prefixes: torch.LongTensor = None,
                target_prefixes: torch.LongTensor = None,
                rollout_mixing_func: RolloutMixingProbFuncType = None,
                reference_policy_type:str = "copy",
                rollout_mode: str = None,
               ):
        state['rollout_params'] = {}
        num_steps_to_take = rollout_steps
        if reference_policy_type == 'oracle':
            reference_policy = partial(self.oracle_reference_policy,
                                        token_to_idx=self._vocab._token_to_index['target_tokens'],
                                        idx_to_token=self._vocab._index_to_token['target_tokens'],
                                       )
            state['rollout_params']['rollout_prefixes'] = prediction_prefixes
        else:
            reference_policy = partial(self.copy_reference_policy,
                                        target_tokens=target_tokens)

        rollout_policy = partial(self.rollout_policy,
                                    rollout_mode=rollout_mode,
                                    rollout_mixing_func=rollout_mixing_func,
                                    reference_policy=reference_policy,
                                )
        rolling_policy=partial(self.take_step,
                               rollout_policy=rollout_policy)

        # shape (step_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        # shape (logits): (batch_size, beam_size, num_decoding_steps, num_classes)
        step_predictions, log_probabilities, logits = \
                    self._beam_search.search(start_predictions,
                                                state,
                                                rolling_policy,
                                                max_steps=num_steps_to_take,
                                                beam_size=beam_size,
                                                per_node_beam_size=per_node_beam_size,
                                                sampled=sampled,
                                                truncate_at_end_all=truncate_at_end_all)

        logits = torch.cat(logits, dim=2)
        
        # Concatenate the start tokens to the predictions. They are not
        # added to the predictions by default.
        batch_size, beam_size, _ = step_predictions.shape

        step_predictions = torch.cat([start_predictions.unsqueeze(1) \
                                        .expand(batch_size, beam_size) \
                                        .reshape(batch_size, beam_size, 1),
                                        step_predictions],
                                        dim=-1)

        # Some predictions might already have been made by the
        # rollin policy. If passed, concatenate them here.
        if prediction_prefixes is not None:
            prefixes_length = prediction_prefixes.size(1)
            step_predictions = torch.cat([prediction_prefixes.unsqueeze(1)\
                                            .expand(batch_size, beam_size, prefixes_length), 
                                         step_predictions],
                                         dim=-1)

        step_prediction_masks = self._get_mask(step_predictions \
                                                .reshape(batch_size * beam_size, -1)) \
                                        .reshape(batch_size, beam_size, -1)

        output_dict = {
            "predictions": step_predictions,
            "prediction_masks": step_prediction_masks,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }

        step_targets = None
        step_target_masks = None
        if target_tokens is not None:
            step_targets = util.get_token_ids_from_text_field_tensors(target_tokens)
            if target_prefixes is not None:
                prefixes_length = target_prefixes.size(1)
                step_targets = torch.cat([target_prefixes, step_targets], dim=-1)

            step_target_masks = util.get_text_field_mask({'tokens': {'tokens': step_targets}})
            
            output_dict.update({
                "targets": step_targets,
                "target_masks": step_target_masks,
            })
        return output_dict

    def compute_sentence_probs(self,
                               sequences_dict: Dict[str, torch.LongTensor],
                              ) -> torch.FloatTensor:
        """ Given a batch of tokens, compute the per-token log probability of sequences
            given the trained model.

        Arguments:
            sequences_dict {Dict[str, torch.LongTensor]} -- The sequences that needs to be scored.

        Returns:
            seq_probs {torch.FloatTensor} -- Probabilities of the sequence.
            seq_lens {torch.LongTensor} -- Length of the non padded sequence.
            per_step_seq_probs {torch.LongTensor} -- Probability of per prediction in a sequence
        """
        state = {}
        sequences = util.get_token_ids_from_text_field_tensors(sequences_dict)

        batch_size = sequences.size(0)
        seq_len = sequences.size(1)
        start_predictions = self._get_start_predictions(state,
                                                        sequences_dict,
                                                        batch_size)
        
        # We are computing the probability of the given sequences themselves,
        # so we use rollin_mode='teacher_forcing' to feed in the tokens of the
        # sequences whose probability we need to compute.
        rollin_output_dict = self.rollin(state={},
                                            start_predictions=start_predictions,
                                            rollin_steps=seq_len - 1,
                                            target_tokens=sequences_dict,
                                            rollin_mode='teacher_forcing',
                                        )

        step_log_probs = F.log_softmax(rollin_output_dict['logits'].squeeze(1), dim=-1)
        per_step_seq_probs = torch.gather(step_log_probs, 2,
                                          sequences[:,1:].unsqueeze(2)) \
                                            .squeeze(2)

        sequence_mask = util.get_text_field_mask(sequences_dict)
        per_step_seq_probs_summed = torch.sum(per_step_seq_probs * sequence_mask[:, 1:], dim=-1)
        non_batch_dims = tuple(range(1, len(sequence_mask.shape)))

        # shape : (batch_size,)
        sequence_mask_sum = sequence_mask[:, 1:].sum(dim=non_batch_dims)

        # (seq_probs, seq_lens, per_step_seq_probs)
        return torch.exp(per_step_seq_probs_summed/sequence_mask_sum), \
                sequence_mask_sum, \
                torch.exp(per_step_seq_probs)
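
    # A minimal usage sketch for ``compute_sentence_probs`` (names hypothetical;
    # ``model`` is a trained decoder instance and ``sequences_dict`` a dict of
    # token-id tensors as produced by a TextField):
    #   seq_probs, seq_lens, per_step_probs = model.compute_sentence_probs(sequences_dict)
    #   # seq_probs:      (batch_size,) length-normalized sequence probabilities
    #   # per_step_probs: (batch_size, seq_len - 1) per-token probabilities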

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      start_predictions: torch.LongTensor,
                      num_decoding_steps: int,
                      target_tokens: Dict[str, torch.LongTensor] = None,
                     ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        raise NotImplementedError()

    def _get_start_predictions(self,
                               state: Dict[str, torch.Tensor],
                               target_tokens: Dict[str, torch.LongTensor] = None,
                               generation_batch_size: int = None) -> torch.LongTensor:

        if self._seq2seq_mode:
            source_mask = state["source_mask"]
            batch_size = source_mask.size()[0]
        elif target_tokens:
            targets = util.get_token_ids_from_text_field_tensors(target_tokens)
            batch_size = targets.size(0)
        else:
            batch_size = generation_batch_size

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        return torch.full((batch_size,),
                          fill_value=self._start_index,
                          dtype=torch.long,
                          device=self.current_device)

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state.get("encoder_outputs", None)

        # shape: (group_size, max_input_sequence_length)
        source_mask = state.get("source_mask", None)

        # shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_predictions = state.get("previous_steps_predictions", None)

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze(1)

        if previous_steps_predictions is None or previous_steps_predictions.shape[-1] == 0:
            # There are no previous steps, except for the start vectors in `last_predictions`.
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
        else:
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions = torch.cat(
                [previous_steps_predictions, last_predictions_embeddings], 1
            )

        decoder_state, decoder_output = self._decoder_net(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions,
        )
        
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with new decoder state, override previous state
        state.update(decoder_state)
        
        if self._decoder_net.decodes_parallel:
            decoder_output = decoder_output[:, -1, :]
        
        # add dropout
        decoder_hidden_with_dropout = self._dropout(decoder_output)

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden_with_dropout)

        return output_projections, state
  
    def _get_mask(self, predictions) -> torch.LongTensor:
        # SEARNN with KL might not produce sequences that match the target
        # sequence in length. This is especially true for LMs trained with
        # learned rollins, where the observed pattern is that sequence
        # lengths keep shrinking.

        # This code computes a mask over the predicted tokens by finding the
        # first time the EOS token is produced; everything after that is
        # masked out.
        mask = predictions.new_ones(predictions.shape)
        for i, indices in enumerate(predictions.detach().cpu().tolist()):
            if self._end_index in indices:
                end_idx = indices.index(self._end_index)
                mask[i, :end_idx + 1] = 1
                mask[i, end_idx + 1:] = 0
        return mask
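
    # A vectorized alternative to the loop above (a sketch; assumes
    # ``predictions`` is a 2D LongTensor): count how many EOS tokens occur
    # strictly before each position and keep only positions where that count
    # is zero. This keeps the first EOS itself, matching the loop:
    #   is_eos = (predictions == self._end_index).long()
    #   eos_before = is_eos.cumsum(dim=-1) - is_eos
    #   mask = (eos_before == 0).long()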
        
    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}

        all_metrics.update({
            'ss_ratio': self._ss_ratio.get_metric(reset=reset),
            'training_iter': self.training_iteration,
            'perplexity': self._perplexity.get_metric(reset=reset),
        })

        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))

        if self._hamming and not self.training:
            all_metrics.update({'hamming': self._hamming.get_metric(reset=reset)})

        if self._loss_criterion and self._loss_criterion._shall_compute_rollout_loss:
            all_metrics.update(self._loss_criterion.get_metric(reset=reset))

        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(reset=reset)  # type: ignore
                )
            if self._token_based_metric is not None:
                all_metrics.update(self._token_based_metric.get_metric(reset=reset))  # type: ignore

        return all_metrics
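
The length normalization in ``compute_sentence_probs`` above amounts to a per-token
geometric mean of the step probabilities. A self-contained sketch of that arithmetic
(values hypothetical, not from any real run):

import torch

# Hypothetical per-step log probabilities for one sequence with 3 scored steps.
per_step_log_probs = torch.tensor([[-0.5, -1.0, -1.5]])
mask = torch.tensor([[1.0, 1.0, 1.0]])

summed = (per_step_log_probs * mask).sum(dim=-1)  # tensor([-3.0])
length = mask.sum(dim=-1)                         # tensor([3.0])
seq_prob = torch.exp(summed / length)             # exp(-1.0) ~= 0.368
print(seq_prob)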
Ejemplo n.º 26
0
class AtisSemanticParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    input_attention: ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
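    training_beam_size : ``int``, optional (default=None)
        If given, limits the beam size used by the ``MaximumMarginalLikelihood``
        decoder trainer when scoring target action sequences during training.
    decoder_num_layers : ``int``, optional (default=1)
        Number of layers in the decoder LSTM.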
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    database_file: ``str``, optional (default=/atis/atis.db)
        The path of the SQLite database when evaluating SQL queries. SQLite is disk based, so we need
        the file location to connect to it.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        utterance_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        decoder_beam_search: BeamSearch,
        max_decoding_steps: int,
        input_attention: Attention,
        add_action_bias: bool = True,
        training_beam_size: int = None,
        decoder_num_layers: int = 1,
        dropout: float = 0.0,
        rule_namespace: str = "rule_labels",
        database_file="/atis/atis.db",
    ) -> None:
        # Atis semantic parser init
        super().__init__(vocab)
        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        self._executor = SqlExecutor(database_file)
        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._num_entity_types = 2  # TODO(kevin): get this in a more principled way somehow?
        self._entity_type_decoder_embedding = Embedding(
            num_embeddings=self._num_entity_types,
            embedding_dim=action_embedding_dim)
        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._transition_function = LinkingTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers,
        )

    def forward(
        self,  # type: ignore
        utterance: Dict[str, torch.LongTensor],
        world: List[AtisWorld],
        actions: List[List[ProductionRule]],
        linking_scores: torch.Tensor,
        target_action_sequence: torch.LongTensor = None,
        sql_queries: List[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        utterance : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        world : ``List[AtisWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        linking_scores: ``torch.Tensor``
            A matrix linking the utterance tokens and the entities. This is a binary matrix that
            is deterministically generated, where each entry indicates whether a token generated an entity.
            This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
        target_action_sequence : torch.Tensor, optional (default=None)
            The correct action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.
        sql_queries : List[List[str]], optional (default=None)
            A list of the SQL queries that are given during training or validation.
        """
        initial_state = self._get_initial_state(utterance, world, actions,
                                                linking_scores)
        batch_size = linking_scores.shape[0]
        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we unsqueeze it for
            # the MML trainer.
            return self._decoder_trainer.decode(
                initial_state,
                self._transition_function,
                (target_action_sequence.unsqueeze(1),
                 target_mask.unsqueeze(1)),
            )
        else:
            # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {"action_mapping": action_mapping}
            outputs["linking_scores"] = linking_scores
            if target_action_sequence is not None:
                outputs["loss"] = self._decoder_trainer.decode(
                    initial_state,
                    self._transition_function,
                    (target_action_sequence.unsqueeze(1),
                     target_mask.unsqueeze(1)),
                )["loss"]
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False,
            )
            outputs["best_action_sequence"] = []
            outputs["debug_info"] = []
            outputs["entities"] = []
            outputs["predicted_sql_query"] = []
            outputs["sql_queries"] = []
            outputs["utterance"] = []
            outputs["tokenized_utterance"] = []

            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs["predicted_sql_query"].append("")
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [
                    action_mapping[(i, action_index)]
                    for action_index in best_action_indices
                ]
                predicted_sql_query = action_sequence_to_sql(action_strings)

                if target_action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = target_action_sequence[i].data
                    sequence_in_targets = self._action_history_match(
                        best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(
                        None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                if sql_queries and sql_queries[i]:
                    denotation_correct = self._executor.evaluate_sql_query(
                        predicted_sql_query, sql_queries[i])
                    self._denotation_accuracy(denotation_correct)
                    outputs["sql_queries"].append(sql_queries[i])

                outputs["utterance"].append(world[i].utterances[-1])
                outputs["tokenized_utterance"].append([
                    token.text for token in world[i].tokenized_utterances[-1]
                ])
                outputs["entities"].append(world[i].entities)
                outputs["best_action_sequence"].append(action_strings)
                outputs["predicted_sql_query"].append(
                    sqlparse.format(predicted_sql_query, reindent=True))
                outputs["debug_info"].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
            return outputs

    def _get_initial_state(
        self,
        utterance: Dict[str, torch.LongTensor],
        worlds: List[AtisWorld],
        actions: List[List[ProductionRule]],
        linking_scores: torch.Tensor,
    ) -> GrammarBasedState:
        embedded_utterance = self._utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance)

        batch_size = embedded_utterance.size(0)
        num_entities = max([len(world.entities) for world in worlds])

        # entity_types: tensor with shape (batch_size, num_entities)
        entity_types, _ = self._get_type_vector(worlds, num_entities,
                                                embedded_utterance)

        # (batch_size, num_utterance_tokens, embedding_dim)
        encoder_input = embedded_utterance

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                initial_rnn_state.append(
                    RnnStatelet(
                        final_encoder_output[i].repeat(
                            self._decoder_num_layers, 1),
                        memory_cell[i].repeat(self._decoder_num_layers, 1),
                        self._first_action_embedding,
                        self._first_attended_utterance,
                        encoder_output_list,
                        utterance_mask_list,
                    ))
            else:
                initial_rnn_state.append(
                    RnnStatelet(
                        final_encoder_output[i],
                        memory_cell[i],
                        self._first_action_embedding,
                        self._first_attended_utterance,
                        encoder_output_list,
                        utterance_mask_list,
                    ))

        initial_grammar_state = [
            self._create_grammar_state(worlds[i], actions[i],
                                       linking_scores[i], entity_types[i])
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=None,
        )
        return initial_state

    @staticmethod
    def _get_type_vector(
        worlds: List[AtisWorld],
        num_entities: int,
        tensor: torch.Tensor = None
    ) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the encoding for each entity's type. In addition, a map from a flattened entity
        index to type is returned to combine entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[AtisWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []

        for batch_index, world in enumerate(worlds):
            types = []
            entities = [("number", entity) if any([
                entity.startswith(numeric_nonterminal)
                for numeric_nonterminal in NUMERIC_NONTERMINALS
            ]) else ("string", entity) for entity in world.entities]

            for entity_index, entity in enumerate(entities):
                # We need numbers to be first, then strings, since our entities are going to be
                # sorted. We do a split by type and then a merge later, and it relies on this sorting.
                if entity[0] == "number":
                    entity_type = 1
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)

        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types
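
    # For example (hypothetical values): with ``num_entities == 3`` and a world
    # whose entities are ["time_range_start_0", "boston"], the produced types
    # would be [1, 0] (numbers get type 1, strings type 0), padded to
    # [1, 0, 0], assuming "time_range_start_0" starts with a numeric
    # non-terminal from NUMERIC_NONTERMINALS.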

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence exactly matches a prefix of the targets.
        return predicted_tensor.equal(targets_trimmed)

    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track four metrics here:

            1. exact_match, which is the percentage of the time that our best output action sequence
            matches the SQL query exactly.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that can be parsed. (make sure
            you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. valid_sql_query, which is the percentage of time that decoding actually produces a
            valid SQL query.  We might not produce a valid SQL query if the decoder gets
            into a repetitive loop, or we're trying to produce a super long SQL query and run
            out of time steps, or something.

            4. action_similarity, which is how similar the action sequence predicted is to the actual
            action sequence. This is basically a soft measure of exact_match.
        """
        return {
            "exact_match": self._exact_match.get_metric(reset),
            "denotation_acc": self._denotation_accuracy.get_metric(reset),
            "valid_sql_query": self._valid_sql_query.get_metric(reset),
            "action_similarity": self._action_similarity.get_metric(reset),
        }

    def _create_grammar_state(
        self,
        world: AtisWorld,
        possible_actions: List[ProductionRule],
        linking_scores: torch.Tensor,
        entity_types: torch.Tensor,
    ) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``AtisWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_utterance_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = world.valid_actions
        entity_map = {}
        entities: Iterable[str] = world.entities

        for entity_index, entity in enumerate(entities):
            entity_map[entity] = entity_index

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = (torch.cat(
                    global_action_tensors,
                    dim=0).to(entity_types.device).long())
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]["global"] = (
                    global_input_embeddings,
                    global_output_embeddings,
                    list(global_action_ids),
                )
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = linked_rules
                entity_ids = [entity_map[entity] for entity in entities]
                entity_linking_scores = linking_scores[entity_ids]
                entity_type_tensor = entity_types[entity_ids]
                entity_type_embeddings = (
                    self._entity_type_decoder_embedding(entity_type_tensor).to(
                        entity_types.device).float())
                translated_valid_actions[key]["linked"] = (
                    entity_linking_scores,
                    entity_type_embeddings,
                    list(linked_action_ids),
                )

        return GrammarStatelet(["statement"], translated_valid_actions,
                               self.is_nonterminal)

    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict["action_mapping"]
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict["debug_info"]
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(
                    predicted_actions, debug_info):
                action_info = {}
                action_info["predicted_action"] = predicted_action
                considered_actions = action_debug_info["considered_actions"]
                probabilities = action_debug_info["probabilities"]
                actions = []
                for action, probability in zip(considered_actions,
                                               probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)],
                                        probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info["considered_actions"] = considered_actions
                action_info["action_probabilities"] = probabilities
                action_info["utterance_attention"] = action_debug_info.get(
                    "question_attention", [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
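
``_action_history_match`` above only awards an exact match when the predicted action
indices line up with a prefix of the (possibly padded) target sequence. A small
standalone check of that behaviour (indices hypothetical):

import torch

predicted = [4, 7, 2]
targets = torch.tensor([4, 7, 2, -1, -1])  # padded with the IndexField value -1

# Mirrors the comparison inside AtisSemanticParser._action_history_match.
predicted_tensor = targets.new_tensor(predicted)
print(predicted_tensor.equal(targets[:len(predicted)]))  # True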
Ejemplo n.º 27
0
class GateBidirectionalAttentionFlow(Model):
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 gate_sent_encoder: Seq2SeqEncoder,
                 gate_self_attention_layer: Seq2SeqEncoder,
                 span_gate: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 output_att_scores: bool = True,
                 sent_labels_src: str = 'sp',
                 regularizer: Optional[RegularizerApplicator] = None) -> None:

        super(GateBidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder

        self._dropout = torch.nn.Dropout(p=dropout)

        self._output_att_scores = output_att_scores

        self._sent_labels_src = sent_labels_src


        self._span_gate = span_gate

        if span_gate._gate_self_att:
            self._gate_sent_encoder = gate_sent_encoder
            self._gate_self_attention_layer = gate_self_attention_layer
        else:
            self._gate_sent_encoder = None
            self._gate_self_attention_layer = None

        self._f1_metrics = F1Measure(1)

        self.evd_ans_metric = Average()

        self._loss_trackers = {'loss': Average()}

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        if self._sent_labels_src == 'chain':
            batch_size, num_spans = sent_labels.size()
            sent_labels_mask = (sent_labels >= 0).float()
            print("chain:", evd_chain_labels)
            # we use the chain as the label to supervise the gate
            # In this model, we only take the first chain in ``evd_chain_labels`` for supervision,
            # right now the number of chains should only be one too.
            evd_chain_labels = evd_chain_labels[:, 0].long()
            # build the gate labels. The dim is set to 1 + num_spans to account for the end embedding
            # shape: (batch_size, 1+num_spans)
            sent_labels = sent_labels.new_zeros((batch_size, 1+num_spans))
            sent_labels.scatter_(1, evd_chain_labels, 1.)
            # remove the column for end embedding
            # shape: (batch_size, num_spans)
            sent_labels = sent_labels[:, 1:].float()
            # make the padding be -1
            sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

        print('\nBert wordpiece size:', passage['bert'].shape)
        # bert embedding for answer prediction
        # shape: [batch_size, max_q_len, emb_size]
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=0)
        # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage, num_wrapping_dims=1)
        # mask
        ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()

        # gate prediction
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        # Shape(gate_mask): (batch_size, num_spans)
        gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(embedded_passage, context_mask,
                                                                         self._gate_self_attention_layer,
                                                                         self._gate_sent_encoder)
        batch_size, num_spans, max_batch_span_width = context_mask.size()

        loss = F.nll_loss(F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
                          sent_labels.long().view(batch_size * num_spans), ignore_index=-1)

        gate = (gate >= 0.3).long()
        gate = gate.view(batch_size, num_spans)

        output_dict = {
            "pred_sent_labels": gate, #[B, num_span]
            "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans), #[B, num_span]
        }
        if self._output_att_scores:
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # Compute the loss for training.
        try:
            self._loss_trackers['loss'](loss)
            output_dict["loss"] = loss
        except RuntimeError:
            # ``span_start_logits`` is not defined in this model, so log the
            # gate logits' shape instead when loss tracking fails.
            print('\n meta_data:', metadata)
            print(gate_logit.shape)

        print("sent label:")
        for b_label in np.array(sent_labels.cpu()):
            b_label = b_label == 1
            indices = np.arange(len(b_label))
            print(indices[b_label] + 1)
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_sent_tokens'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                # shift sentence indices back
                evd_possible_chains.append([s_idx-1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0])
                ans_sent_idxs.append([s_idx-1 for s_idx in metadata[i]['ans_sent_idxs']])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_gate = gate[i].detach().cpu().numpy()
                    if any([pred_sent_gate[s_idx-1] > 0 for s_idx in metadata[i]['ans_sent_idxs']]):
                        self.evd_ans_metric(1)
                    else:
                        self.evd_ans_metric(0)
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_sent_tokens'] = passage_tokens
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['_id'] = ids

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        p, r, evidence_f1_score = self._f1_metrics.get_metric(reset)
        ans_in_evd = self.evd_ans_metric.get_metric(reset)
        metrics = {
                'evd_p': p,
                'evd_r': r,
                'evd_f1': evidence_f1_score,
                'ans_in_evd': ans_in_evd
                }
        for name, tracker in self._loss_trackers.items():
            metrics[name] = tracker.get_metric(reset).item()
        return metrics

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
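
``get_best_span`` above scans each passage position once, tracking the best start
logit seen so far, so the best (start, end) pair with start <= end is found in a
single pass. A quick sanity check with hypothetical logits (assumes the class
definition above is importable):

import torch

span_start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.1]])
span_end_logits = torch.tensor([[0.2, 0.1, 1.5, 0.4]])

best = GateBidirectionalAttentionFlow.get_best_span(span_start_logits, span_end_logits)
print(best)  # tensor([[1, 2]]): start 1 and end 2 maximize the summed logits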
Ejemplo n.º 28
0
class NlvrCoverageSemanticParser(NlvrSemanticParser):
    """
    ``NlvrCoverageSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of lack
    of annotated logical forms by maximizing coverage of the output sequences over a prespecified
    agenda. In addition to the signal from coverage, we also compute the denotations given by the
    logical forms and define a hybrid cost based on coverage and denotation errors. The training
    process then minimizes the expected value of this cost over an approximate set of logical forms
    produced by the parser, obtained by performing beam search.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the TransitionFunction.
    beam_size : ``int``
        Beam size for the beam search used during training.
    max_num_finished_states : ``int``, optional (default=None)
        Maximum number of finished states the trainer should compute costs for.
    normalize_beam_score_by_length : ``bool``, optional (default=False)
        Should the log probabilities be normalized by length before renormalizing them? Edunov et
        al. do this in their work, but we found that not doing it works better. It's possible they
        did this because their task is NMT, and longer decoded sequences are not necessarily worse,
        and shouldn't be penalized, while we will mostly want to penalize longer logical forms.
    max_decoding_steps : ``int``
        Maximum number of steps for the beam search during training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    checklist_cost_weight : ``float``, optional (default=0.6)
        Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases, we
        weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
        denotation accuracy.
    dynamic_cost_weight : ``Dict[str, Union[int, float]]``, optional (default=None)
        A dict containing keys ``wait_num_epochs`` and ``rate`` indicating the number of steps
        after which we should start decreasing the weight on checklist cost in favor of denotation
        cost, and the rate at which we should do it. We will decrease the weight in the following
        way - ``checklist_cost_weight = checklist_cost_weight - rate * checklist_cost_weight``
        starting at the appropriate epoch.  The weight will remain constant if this is not provided.
    penalize_non_agenda_actions : ``bool``, optional (default=False)
        Should we penalize the model for producing terminal actions that are outside the agenda?
    initial_mml_model_file : ``str`` , optional (default=None)
        If you want to initialize this model using weights from another model trained using MML,
        pass the path to the ``model.tar.gz`` file of that model here.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 beam_size: int,
                 max_decoding_steps: int,
                 max_num_finished_states: int = None,
                 dropout: float = 0.0,
                 normalize_beam_score_by_length: bool = False,
                 checklist_cost_weight: float = 0.6,
                 dynamic_cost_weight: Dict[str, Union[int, float]] = None,
                 penalize_non_agenda_actions: bool = False,
                 initial_mml_model_file: str = None) -> None:
        super(NlvrCoverageSemanticParser,
              self).__init__(vocab=vocab,
                             sentence_embedder=sentence_embedder,
                             action_embedding_dim=action_embedding_dim,
                             encoder=encoder,
                             dropout=dropout)
        self._agenda_coverage = Average()
        self._decoder_trainer: DecoderTrainer[Callable[[CoverageState], torch.Tensor]] = \
                ExpectedRiskMinimization(beam_size=beam_size,
                                         normalize_by_length=normalize_beam_score_by_length,
                                         max_decoding_steps=max_decoding_steps,
                                         max_num_finished_states=max_num_finished_states)

        # Instantiating an empty NlvrWorld just to get the terminal productions.
        self._terminal_productions = set(
            NlvrWorld([]).terminal_productions.values())
        self._decoder_step = CoverageTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            num_start_types=1,
            activation=Activation.by_name('tanh')(),
            predict_start_type_separately=False,
            add_action_bias=False,
            dropout=dropout)
        self._checklist_cost_weight = checklist_cost_weight
        self._dynamic_cost_wait_epochs = None
        self._dynamic_cost_rate = None
        if dynamic_cost_weight:
            self._dynamic_cost_wait_epochs = dynamic_cost_weight[
                "wait_num_epochs"]
            self._dynamic_cost_rate = dynamic_cost_weight["rate"]
        self._penalize_non_agenda_actions = penalize_non_agenda_actions
        self._last_epoch_in_forward: int = None
        # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
        # copied a trained ERM model from a different machine and the original MML model that was
        # used to initialize it does not exist on the current machine. This may not be the best
        # solution for the problem.
        if initial_mml_model_file is not None:
            if os.path.isfile(initial_mml_model_file):
                archive = load_archive(initial_mml_model_file)
                self._initialize_weights_from_archive(archive)
            else:
                # A model file is passed, but it does not exist. This is expected to happen when
                # you're using a trained ERM model to decode. But it may also happen if the path to
                # the file is really just incorrect. So throwing a warning.
                logger.warning(
                    "MML model file for initializing weights is passed, but does not exist."
                    " This is fine if you're just decoding.")

    def _initialize_weights_from_archive(self, archive: Archive) -> None:
        logger.info("Initializing weights from MML model.")
        model_parameters = dict(self.named_parameters())
        archived_parameters = dict(archive.model.named_parameters())
        sentence_embedder_weight = "_sentence_embedder.token_embedder_tokens.weight"
        if sentence_embedder_weight not in archived_parameters or \
           sentence_embedder_weight not in model_parameters:
            raise RuntimeError(
                "When initializing model weights from an MML model, we need "
                "the sentence embedder to be a TokenEmbedder using namespace called "
                "tokens.")
        for name, weights in archived_parameters.items():
            if name in model_parameters:
                if name == "_sentence_embedder.token_embedder_tokens.weight":
                    # The shapes of embedding weights will most likely differ between the two models
                    # because the vocabularies will most likely be different. We will get a mapping
                    # of indices from this model's token indices to the archived model's and copy
                    # the tensor accordingly.
                    vocab_index_mapping = self._get_vocab_index_mapping(
                        archive.model.vocab)
                    archived_embedding_weights = weights.data
                    new_weights = model_parameters[name].data.clone()
                    for index, archived_index in vocab_index_mapping:
                        new_weights[index] = archived_embedding_weights[
                            archived_index]
                    logger.info("Copied embeddings of %d out of %d tokens",
                                len(vocab_index_mapping),
                                new_weights.size()[0])
                else:
                    new_weights = weights.data
                logger.info("Copying parameter %s", name)
                model_parameters[name].data.copy_(new_weights)

    def _get_vocab_index_mapping(
            self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]:
        vocab_index_mapping: List[Tuple[int, int]] = []
        for index in range(self.vocab.get_vocab_size(namespace='tokens')):
            token = self.vocab.get_token_from_index(index=index,
                                                    namespace='tokens')
            archived_token_index = archived_vocab.get_token_index(
                token, namespace='tokens')
            # Checking if we got the UNK token index, because we don't want all new token
            # representations initialized to UNK token's representation. We do that by checking if
            # the two tokens are the same. They will not be if the token at the archived index is
            # UNK.
            if archived_vocab.get_token_from_index(
                    archived_token_index, namespace="tokens") == token:
                vocab_index_mapping.append((index, archived_token_index))
        return vocab_index_mapping
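
    # For example (hypothetical): if the token "box" has index 42 in this
    # model's vocabulary and index 17 in the archived vocabulary, the mapping
    # contains (42, 17). Tokens that resolve to UNK in the archive are skipped
    # so that new tokens keep their fresh initialization.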

    @overrides
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            identifier: List[str] = None,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences that maximize coverage of
        their respective agendas, and minimize a denotation based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError(
                    "If you want a dynamic cost weight, use the "
                    "EpochTrackingBucketIterator!")
            if instance_epoch_num != self._last_epoch_in_forward:
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f",
                                self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
            for i in range(batch_size)
        ]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            initial_checklist_states.append(
                ChecklistStatelet(terminal_actions=terminal_actions,
                                  checklist_target=checklist_target,
                                  checklist_mask=checklist_mask,
                                  checklist=initial_checklist))
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states)

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(
            initial_state,  # type: ignore
            self._decoder_step,
            partial(self._get_state_cost, worlds))
        if identifier is not None:
            outputs['identifier'] = identifier
        best_final_states = outputs['best_final_states']
        best_action_sequences = {}
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [
                state.action_history[0] for state in states
            ]
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings,
                                 possible_actions=actions,
                                 agenda_data=agenda_data)
        else:
            # We're testing.
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs

    def _get_checklist_info(
        self, agenda: torch.LongTensor, all_actions: List[ProductionRule]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Takes an agenda and a list of all actions and returns a target checklist against which the
        checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
        and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
        checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``,
        ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
        ``False``, indices of all terminals that are not in the agenda will be masked.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRule]``
            All actions for one instance.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = {
            int(x) for x in agenda.squeeze(0).detach().cpu().numpy()}
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRule, a tuple where the first item is the production
            # rule string.
            if action[0] in self._terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        # (num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list,
                                             dtype=torch.float)
        if self._penalize_non_agenda_actions:
            # All terminal actions are relevant
            checklist_mask = torch.ones_like(target_checklist)
        else:
            checklist_mask = (target_checklist != 0).float()
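        # Worked micro-example (hypothetical numbers, not from the original
        # source): given agenda indices {2, 5} and terminal productions at
        # indices [1, 2, 5, 7], we return
        #     terminal_actions = [[1], [2], [5], [7]]
        #     target_checklist = [[0.], [1.], [1.], [0.]]
        # and, with penalize_non_agenda_actions False, checklist_mask equals
        # target_checklist, so only agenda terminals contribute to the loss.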
        return target_checklist, terminal_actions, checklist_mask

    def _update_metrics(self, action_strings: List[List[List[str]]],
                        worlds: List[List[NlvrWorld]],
                        label_strings: List[List[str]],
                        possible_actions: List[List[ProductionRule]],
                        agenda_data: List[List[int]]) -> None:
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): action_strings contains k-best lists. This method only uses the top decoded
        # sequence currently. Maybe define top-k metrics?
        batch_size = len(worlds)
        for i in range(batch_size):
            # Using only the top decoded sequence per instance.
            instance_action_strings = action_strings[i][0] if action_strings[i] else []
            sequence_is_correct = [False]
            in_agenda_ratio = 0.0
            instance_possible_actions = possible_actions[i]
            if instance_action_strings:
                terminal_agenda_actions = []
                for rule_id in agenda_data[i]:
                    if rule_id == -1:
                        continue
                    action_string = instance_possible_actions[rule_id][0]
                    right_side = action_string.split(" -> ")[1]
                    if right_side.isdigit() or ('[' not in right_side
                                                and len(right_side) > 1):
                        terminal_agenda_actions.append(action_string)
                actions_in_agenda = [
                    action in instance_action_strings
                    for action in terminal_agenda_actions
                ]
                in_agenda_ratio = sum(actions_in_agenda) / len(
                    actions_in_agenda)
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                sequence_is_correct = self._check_denotation(
                    instance_action_strings, instance_label_strings,
                    instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)
            self._agenda_coverage(in_agenda_ratio)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'denotation_accuracy': self._denotation_accuracy.get_metric(reset),
            'consistency': self._consistency.get_metric(reset),
            'agenda_coverage': self._agenda_coverage.get_metric(reset)
        }

    def _get_state_cost(self, batch_worlds: List[List[NlvrWorld]],
                        state: CoverageState) -> torch.Tensor:
        """
        Return the cost of a finished state. Since it is a finished state, the group size will be
        1, and hence we'll return just one cost.

        The ``batch_worlds`` parameter here is because we need the world to check the denotation
        accuracy of the action sequence in the finished state.  Instead of adding a field to the
        ``State`` object just for this method, we take the ``World`` as a parameter here.
        """
        if not state.is_finished():
            raise RuntimeError(
                "_get_state_cost() is not defined for unfinished states!")
        instance_worlds = batch_worlds[state.batch_indices[0]]
        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask.
        checklist_balance = state.checklist_state[0].get_balance()
        checklist_cost = torch.sum(checklist_balance ** 2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        # Note: If we are penalizing the model for producing non-agenda actions, this is not the
        # upper limit on the checklist cost. That would be the number of terminal actions.
        denotation_cost = torch.sum(
            state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        # TODO (pradeep): The denotation based cost below is strict. Maybe define a cost based on
        # how many worlds the logical form is correct in?
        # extras being None happens when we are testing. We do not care about the cost
        # then.  TODO (pradeep): Make this cleaner.
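        # In equation form, with w = self._checklist_cost_weight and
        # balance = mask * (target - checklist):
        #     cost = w * ||balance||^2                            if all denotations are correct (or we're testing)
        #     cost = w * ||balance||^2 + (1 - w) * sum(target)    otherwise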
        if state.extras is None or all(
                self._check_state_denotations(state, instance_worlds)):
            cost = checklist_cost
        else:
            cost = checklist_cost + (
                1 - self._checklist_cost_weight) * denotation_cost
        return cost

    def _get_state_info(
            self, state: CoverageState,
            batch_worlds: List[List[NlvrWorld]]) -> Dict[str, List]:
        """
        This method is here for debugging purposes, in case you want to look at what the model
        is learning. It may be inefficient to call it while training the model on real data.
        """
        if len(state.batch_indices) == 1 and state.is_finished():
            costs = [
                float(
                    self._get_state_cost(batch_worlds,
                                         state).detach().cpu().numpy())
            ]
        else:
            costs = []
        model_scores = [
            float(score.detach().cpu().numpy()) for score in state.score
        ]
        all_actions = state.possible_actions[0]
        action_sequences = [[
            self._get_action_string(all_actions[action]) for action in history
        ] for history in state.action_history]
        agenda_sequences = []
        all_agenda_indices = []
        for checklist_state in state.checklist_state:
            agenda_indices = []
            for action, is_wanted in zip(checklist_state.terminal_actions,
                                         checklist_state.checklist_target):
                action_int = int(action.detach().cpu().numpy())
                is_wanted_int = int(is_wanted.detach().cpu().numpy())
                if is_wanted_int != 0:
                    agenda_indices.append(action_int)
            agenda_sequences.append([
                self._get_action_string(all_actions[action])
                for action in agenda_indices
            ])
            all_agenda_indices.append(agenda_indices)
        return {
            "agenda": agenda_sequences,
            "agenda_indices": all_agenda_indices,
            "history": action_sequences,
            "history_indices": state.action_history,
            "costs": costs,
            "scores": model_scores
        }
Ejemplo n.º 29
0
class QuarelSemanticParser(Model):
    """
    A ``QuarelSemanticParser`` is a variant of ``WikiTablesSemanticParser`` with various
    tweaks and changes.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=0)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  If this is 0, no learned feature scores are added, and the linking
        score is just the question-entity word similarity (when entities are used at all).
    use_entities : ``bool``, optional (default=False)
        Whether dynamic entities are part of the action space
    num_entity_bits : ``int``, optional (default=0)
        Number of bits added to the encoder input/output to represent tagged entities
    entity_bits_output : ``bool``, optional (default=True)
        Whether entity bits are added to the encoder output (``True``) or input (``False``)
    denotation_only : ``bool``, optional (default=False)
        Whether to only predict the target denotation, skipping the whole logical form decoder
    entity_similarity_mode : ``str``, optional (default="dot_product")
        How to compute vector similarity between question and entity tokens, can take values
        "dot_product" or "weighted_dot_product" (learned weights on each dimension)
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 num_linking_features: int = 0,
                 num_entity_bits: int = 0,
                 entity_bits_output: bool = True,
                 use_entities: bool = False,
                 denotation_only: bool = False,
                 # Deprecated parameter to load older models
                 entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
                 entity_similarity_mode: str = "dot_product",
                 rule_namespace: str = 'rule_labels') -> None:
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 1 # Hardcoded until we feed lf syntax into the model
        self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = \
                TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
            self._entity_similarity_layer._module.weight.data += 1  # pylint: disable=protected-access
        elif self._entity_similarity_mode == "dot_product":
            pass
        else:
            raise ValueError("Invalid entity_similarity_mode: {}".format(self._entity_similarity_mode))

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            self._denotation_classifier = torch.nn.Linear(self._encoder_output_dim,
                                                          self._num_denotation_cats)
            # The rest of init is not needed in denotation-only mode, where we never decode to actions.
            return

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        # We are tying the action embeddings used for input and output
        # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(self._encoder_output_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder_output_dim,
                                                       action_embedding_dim=action_embedding_dim,
                                                       input_attention=attention,
                                                       num_start_types=self._num_start_types,
                                                       predict_start_type_separately=False,
                                                       add_action_bias=self._add_action_bias,
                                                       mixture_feedforward=mixture_feedforward,
                                                       dropout=dropout)

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[QuarelWorld],
                actions: List[List[ProductionRule]],
                entity_bits: torch.Tensor = None,
                denotation_target: torch.Tensor = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                           num_entities * num_entity_tokens,
                                                                           self._embedding_dim),
                                                       torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(batch_size,
                                       num_question_tokens*num_entities*num_entity_tokens,
                                       self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(product)
                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score
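            # In both similarity modes, scores are computed per (entity token,
            # question token) pair, reshaped to (batch_size, num_entities,
            # num_entity_tokens, num_question_tokens), and then max-pooled over
            # entity tokens, giving one linking score per (entity, question
            # token) pair.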

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                    question_mask, entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Dummy linking_scores so that downstream code does not complain.
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(final_encoder_output)
            loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))

        initial_grammar_state = [self._create_grammar_state(world[i], actions[i],
                                                            linking_scores[i], entity_types[i])
                                 for i in range(batch_size)]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=None,
                                          debug_info=None)

        if self.training:
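            # MaximumMarginalLikelihood marginalizes the model's probability
            # over all provided gold action sequences for each instance.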
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
            return outputs

        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            # These outputs only exist when entities and linking features are both in use.
            if self._use_entities and self._linking_params is not None:
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
                    outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
                outputs['entities'].append(world[i].table_graph.entities)
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index)
                        self._denotation_accuracy(denotation_accuracy)
                    score = math.exp(best_final_states[i][0].score[0].data.cpu().item())
                    outputs['answer_index'].append(predicted_answer_index)
                    outputs['score'].append(score)
                    outputs['parse_acc'].append(sequence_in_targets)
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['denotation_acc'].append(denotation_accuracy)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    outputs['parse_acc'].append(0)
                    outputs['logical_form'].append('')
                    outputs['denotation_acc'].append(0)
                    outputs['score'].append(0)
                    outputs['answer_index'].append(-1)
                    outputs['best_action_sequence'].append([])
                    outputs['debug_info'].append([])
                    self._has_logical_form(0.0)
            return outputs

    @staticmethod
    def _get_type_vector(worlds: List[QuarelWorld],
                         num_entities: int,
                         tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's
        type. In addition, a map from a flattened entity index to type is returned to combine
        entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)
        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types

    def _get_linking_probabilities(self,
                                   worlds: List[QuarelWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(num_question_tokens,
                                                 num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
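        # Worked micro-example (hypothetical scores): for two entities of one
        # type with scores [s1, s2] for some question token, we return
        # softmax([0, s1, s2])[1:], so a token can put most of its mass on the
        # prepended "null" entity and effectively link to nothing.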
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
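        # E.g. predicted=[3, 7] matches targets [[3, 7, -1], [2, 7, -1]]: the
        # first target row begins with the predicted sequence, so we return 1.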
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()

    def _denotation_match(self, predicted_answer_index: int, target_answer_index: int) -> float:
        if predicted_answer_index < 0:
            # The logical form didn't resolve to an answer, so we award the
            # expected credit of a random guess.
            return 1.0 / self._num_denotation_cats
        elif predicted_answer_index == target_answer_index:
            return 1.0
        return 0.0

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. parse_acc, which is the percentage of the time that our best output action sequence
            corresponds to a correct logical form

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation, including spurious correct answers using the wrong logical form

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        if self._denotation_only:
            metrics = {'denotation_acc': self._denotation_accuracy_cat.get_metric(reset)}
        else:
            metrics = {
                    'parse_acc': self._action_sequence_accuracy.get_metric(reset),
                    'denotation_acc': self._denotation_accuracy.get_metric(reset),
                    'lf_percent': self._has_logical_form.get_metric(reset),
            }
        return metrics

    def _create_grammar_state(self,
                              world: QuarelWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``QuarelWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [action_map[action_string] for action_string in action_strings]
            production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append((production_rule_array[2], action_index))
                else:
                    linked_actions.append((production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions.
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors, dim=0)
            global_input_embeddings = self._action_embedder(global_action_tensor)
            if self._add_action_bias:
                global_action_biases = self._action_biases(global_action_tensor)
                global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases], dim=-1)
            global_output_embeddings = self._output_action_embedder(global_action_tensor)
            translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                       global_output_embeddings,
                                                       list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
                translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                           entity_type_embeddings,
                                                           list(linked_action_ids))

        return GrammarStatelet([START_SYMBOL],
                               translated_valid_actions,
                               type_declaration.is_nonterminal)

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``FrictionQDecoderStep``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
Ejemplo n.º 30
0
class RedditTask(RankingTask):
    ''' Task class for Reddit data.  '''

    def __init__(self, path, max_seq_len, name, **kw):
        ''' '''
        super().__init__(name, **kw)
        self.scorer1 = Average()  # CategoricalAccuracy()
        self.scorer2 = None
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.files_by_split = {split: os.path.join(path, "%s.csv" % split) for
                               split in ["train", "val", "test"]}
        self.max_seq_len = max_seq_len

    def get_split_text(self, split: str):
        ''' Get split text as iterable of records.

        Split should be one of 'train', 'val', or 'test'.
        '''
        return self.load_data(self.files_by_split[split])

    def load_data(self, path):
        ''' Load data '''
        with open(path, 'r') as txt_fh:
            for row in txt_fh:
                row = row.strip().split('\t')
                if len(row) < 4 or not row[2] or not row[3]:
                    continue
                sent1 = process_sentence(row[2], self.max_seq_len)
                sent2 = process_sentence(row[3], self.max_seq_len)
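                # Every kept row is a positive (context, response) pair;
                # negatives are presumably constructed downstream by the
                # ranking objective.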
                targ = 1
                yield (sent1, sent2, targ)

    def get_sentences(self) -> Iterable[Sequence[str]]:
        ''' Yield sentences, used to compute vocabulary. '''
        for split in self.files_by_split:
            # Don't use test set for vocab building.
            if split.startswith("test"):
                continue
            path = self.files_by_split[split]
            for sent1, sent2, _ in self.load_data(path):
                yield sent1
                yield sent2

    def count_examples(self):
        ''' Compute here b/c we're streaming the sentences. '''
        example_counts = {}
        for split, split_path in self.files_by_split.items():
            with open(split_path) as split_fh:
                example_counts[split] = sum(1 for _ in split_fh)
        self.example_counts = example_counts

    def process_split(self, split, indexers) -> Iterable[Instance]:
        ''' Process split text into a list of AllenNLP Instances. '''
        def _make_instance(input1, input2, labels):
            d = {}
            d["input1"] = sentence_to_text_field(input1, indexers)
            #d['sent1_str'] = MetadataField(" ".join(input1[1:-1]))
            d["input2"] = sentence_to_text_field(input2, indexers)
            #d['sent2_str'] = MetadataField(" ".join(input2[1:-1]))
            d["labels"] = LabelField(labels, label_namespace="labels",
                                     skip_indexing=True)
            return Instance(d)

        for sent1, sent2, trg in split:
            yield _make_instance(sent1, sent2, trg)

    def get_metrics(self, reset=False):
        '''Get metrics specific to the task'''
        acc = self.scorer1.get_metric(reset)
        return {'accuracy': acc}
Ejemplo n.º 31
0
class BCECentreDistanceBoxModel(BCEBoxModel):
    def __init__(self,
                 num_entities: int,
                 num_relations: int,
                 embedding_dim: int,
                 box_type: str = 'SigmoidBoxTensor',
                 single_box: bool = False,
                 softbox_temp: float = 10.,
                 margin: float = 1.,
                 number_of_negative_samples: int = 0,
                 debug: bool = False,
                 regularization_weight: float = 0,
                 init_interval_center: float = 0.25,
                 init_interval_delta: float = 0.1) -> None:
        super().__init__(num_entities,
                         num_relations,
                         embedding_dim,
                         box_type=box_type,
                         single_box=single_box,
                         softbox_temp=softbox_temp,
                         number_of_negative_samples=number_of_negative_samples,
                         debug=debug,
                         regularization_weight=regularization_weight,
                         init_interval_center=init_interval_center,
                         init_interval_delta=init_interval_delta)
        self.number_of_negative_samples = number_of_negative_samples
        # Store the margin; it is needed again in get_scores() below.
        self.margin = margin
        self.centre_loss_metric = Average()
        self.loss_f_centre: torch.nn.modules.loss._Loss = torch.nn.MarginRankingLoss(  # type: ignore
            margin=margin, reduction='mean')

    def get_scores(self, embeddings: Dict) -> torch.Tensor:
        p = self._get_triple_score(embeddings['h'], embeddings['t'],
                                   embeddings['r'])

        diff = (embeddings['h'].centre - embeddings['t'].centre).norm(p=1,
                                                                      dim=-1)
        pos_diff = torch.mean(
            diff[torch.where(embeddings['label'] == 1)]).view(1)
        neg_diff = torch.mean(
            diff[torch.where(embeddings['label'] == 0)]).view(1)

        if torch.isnan(pos_diff):
            pos_diff = torch.Tensor([0.])

        if torch.isnan(neg_diff):
            neg_diff = torch.Tensor([self.margin])

        return p, pos_diff, neg_diff

    def get_loss(self, scores: torch.Tensor,
                 label: torch.Tensor) -> torch.Tensor:
        log_p = scores[0]
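        # log1mexp computes log(1 - exp(log_p)) in a numerically stable way, so
        # ``log1mp`` is log(1 - p); stacking the two gives two-class logits.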
        log1mp = log1mexp(log_p)
        logits = torch.stack([log1mp, log_p], dim=-1)
        centre_loss = self.loss_f_centre(scores[2], scores[1],
                                         torch.Tensor([1]))
        self.centre_loss_metric(centre_loss.item())
        loss = self.loss_f(logits,
                           label) + self.regularization_weight * centre_loss

        return loss

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:

        return {
            'hr_rank': self.head_replacement_rank_avg.get_metric(reset),
            'tr_rank': self.tail_replacement_rank_avg.get_metric(reset),
            'avg_rank': self.avg_rank.get_metric(reset),
            'hitsat10': self.hitsat10.get_metric(reset),
            'hr_mrr': self.head_replacement_mrr.get_metric(reset),
            'tr_mrr': self.tail_replacement_mrr.get_metric(reset),
            'int_volume_train': self.int_volume_train.get_metric(reset),
            'int_volume_dev': self.int_volume_dev.get_metric(reset),
            'regularization_loss': self.regularization_loss.get_metric(reset),
            'hr_hitsat1': self.head_hitsat1.get_metric(reset),
            'tr_hitsat1': self.tail_hitsat1.get_metric(reset),
            'hr_hitsat3': self.head_hitsat3.get_metric(reset),
            'tr_hitsat3': self.tail_hitsat3.get_metric(reset),
            'mrr': self.mrr.get_metric(reset),
            'centre_loss': self.centre_loss_metric.get_metric(reset)
        }

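The centre-distance term above is a plain margin-ranking objective: the mean L1 distance between head and tail centres should be at least ``margin`` larger for negative triples than for positive ones. A minimal, self-contained sketch of that mechanic (toy distances, not the model's real box embeddings):

import torch

margin = 1.0
loss_f_centre = torch.nn.MarginRankingLoss(margin=margin, reduction='mean')

# Toy mean centre distances for positive and negative triples.
pos_diff = torch.tensor([0.3])   # positive pairs should have close centres
neg_diff = torch.tensor([0.9])   # negative pairs should have distant centres

# target = 1 asks the first argument (neg_diff) to outrank the second (pos_diff):
# loss = max(0, -(neg_diff - pos_diff) + margin)
loss = loss_f_centre(neg_diff, pos_diff, torch.tensor([1.0]))
print(loss.item())  # max(0, -(0.9 - 0.3) + 1.0) = 0.4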
Ejemplo n.º 32
0
class WikiTablesErmSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesErmSemanticParser`` is a :class:`WikiTablesSemanticParser` that learns to search
    for logical forms that yield the correct denotations.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity. Passed to super class.
    input_attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to WikiTablesDecoderStep.
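    mixture_feedforward : ``FeedForward``, optional (default=None)
        If given, we use this to compute a mixture probability between global actions and linked
        actions given the decoder hidden state at every timestep, instead of concatenating the
        logits for both. Passed to WikiTablesDecoderStep.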
    decoder_beam_size : ``int``
        Beam size to be used by the ExpectedRiskMinimization algorithm.
    decoder_num_finished_states : ``int``
        Number of finished states for which costs will be computed by the ExpectedRiskMinimization
        algorithm.
    max_decoding_steps : ``int``
        Maximum number of steps the decoder should take before giving up. Used both during training
        and evaluation. Passed to super class.
    normalize_beam_score_by_length : ``bool``, optional (default=False)
        Should we normalize the log-probabilities by length before renormalizing the beam? This was
        shown to work better for NMT by Edunov et al., but that may not be the case for semantic
        parsing.
    checklist_cost_weight : ``float``, optional (default=0.6)
        Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases, we
        weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
        denotation accuracy.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to super
        class.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables. Passed to super class.
    initial_mml_model_file : ``str``, optional (default=None)
        If you want to initialize this model using weights from another model trained using MML,
        pass the path to the ``model.tar.gz`` file of that model here.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 mixture_feedforward: FeedForward,
                 input_attention: Attention,
                 decoder_beam_size: int,
                 decoder_num_finished_states: int,
                 max_decoding_steps: int,
                 normalize_beam_score_by_length: bool = False,
                 checklist_cost_weight: float = 0.6,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/',
                 initial_mml_model_file: str = None) -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(vocab=vocab,
                         question_embedder=question_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         entity_encoder=entity_encoder,
                         max_decoding_steps=max_decoding_steps,
                         use_neighbor_similarity_for_linking=use_similarity,
                         dropout=dropout,
                         num_linking_features=num_linking_features,
                         rule_namespace=rule_namespace,
                         tables_directory=tables_directory)
        # Not sure why mypy needs a type annotation for this!
        self._decoder_trainer: ExpectedRiskMinimization = \
                ExpectedRiskMinimization(beam_size=decoder_beam_size,
                                         normalize_by_length=normalize_beam_score_by_length,
                                         max_decoding_steps=self._max_decoding_steps,
                                         max_num_finished_states=decoder_num_finished_states)
        unlinked_terminals_global_indices = []
        global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace)
        for production, index in global_vocab.items():
            right_side = production.split(" -> ")[1]
            if right_side in types.COMMON_NAME_MAPPING:
                # This is a terminal production.
                unlinked_terminals_global_indices.append(index)
        self._num_unlinked_terminals = len(unlinked_terminals_global_indices)
        self._decoder_step = WikiTablesDecoderStep(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            num_start_types=self._num_start_types,
            num_entity_types=self._num_entity_types,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout,
            unlinked_terminal_indices=unlinked_terminals_global_indices)
        self._checklist_cost_weight = checklist_cost_weight
        self._agenda_coverage = Average()
        # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
        # copied a trained ERM model from a different machine and the original MML model that was
        # used to initialize it does not exist on the current machine. This may not be the best
        # solution for the problem.
        if initial_mml_model_file is not None:
            if os.path.isfile(initial_mml_model_file):
                archive = load_archive(initial_mml_model_file)
                self._initialize_weights_from_archive(archive)
            else:
                # A model file is passed, but it does not exist. This is expected to happen when
                # you're using a trained ERM model to decode. But it may also happen if the path to
                # the file is really just incorrect. So throwing a warning.
                logger.warning(
                    "MML model file for initializing weights is passed, but does not exist."
                    " This is fine if you're just decoding.")

    def _initialize_weights_from_archive(self, archive: Archive) -> None:
        logger.info("Initializing weights from MML model.")
        model_parameters = dict(self.named_parameters())
        archived_parameters = dict(archive.model.named_parameters())
        question_embedder_weight = "_question_embedder.token_embedder_tokens.weight"
        if question_embedder_weight not in archived_parameters or \
           question_embedder_weight not in model_parameters:
            raise RuntimeError(
                "When initializing model weights from an MML model, we need "
                "the question embedder to be a TokenEmbedder using namespace called "
                "tokens.")
        for name, weights in archived_parameters.items():
            if name in model_parameters:
                if name == question_embedder_weight:
                    # The shapes of embedding weights will most likely differ between the two models
                    # because the vocabularies will most likely be different. We will get a mapping
                    # of indices from this model's token indices to the archived model's and copy
                    # the tensor accordingly.
                    vocab_index_mapping = self._get_vocab_index_mapping(
                        archive.model.vocab)
                    archived_embedding_weights = weights.data
                    new_weights = model_parameters[name].data.clone()
                    for index, archived_index in vocab_index_mapping:
                        new_weights[index] = archived_embedding_weights[
                            archived_index]
                    logger.info("Copied embeddings of %d out of %d tokens",
                                len(vocab_index_mapping),
                                new_weights.size()[0])
                else:
                    new_weights = weights.data
                logger.info("Copying parameter %s", name)
                model_parameters[name].data.copy_(new_weights)

    def _get_vocab_index_mapping(
            self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]:
        vocab_index_mapping: List[Tuple[int, int]] = []
        for index in range(self.vocab.get_vocab_size(namespace='tokens')):
            token = self.vocab.get_token_from_index(index=index,
                                                    namespace='tokens')
            archived_token_index = archived_vocab.get_token_index(
                token, namespace='tokens')
            # Checking if we got the UNK token index, because we don't want all new token
            # representations initialized to UNK token's representation. We do that by checking if
            # the two tokens are the same. They will not be if the token at the archived index is
            # UNK.
            if archived_vocab.get_token_from_index(
                    archived_token_index, namespace="tokens") == token:
                vocab_index_mapping.append((index, archived_token_index))
        return vocab_index_mapping
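
To make the mapping concrete, here is a small standalone sketch (toy vocabularies and random embedding matrices, not AllenNLP objects) of copying embedding rows across two different vocabularies while leaving unmatched tokens at their fresh initialization:

import torch

new_vocab = {'the': 0, 'cat': 1, 'sat': 2}   # token -> index in the new model
old_vocab = {'the': 5, 'sat': 7}             # token -> index in the archived model

old_weights = torch.randn(10, 4)                 # archived embedding matrix
new_weights = torch.randn(len(new_vocab), 4)     # freshly initialized matrix

for token, index in new_vocab.items():
    # Copy a row only when the token really exists in the archived vocabulary;
    # otherwise we would be copying the UNK embedding into every new token.
    if token in old_vocab:
        new_weights[index] = old_weights[old_vocab[token]]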

    @overrides
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            agenda: torch.LongTensor,
            example_lisp_string: List[str]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
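        agenda : ``torch.LongTensor``
            Agenda action indices for each instance, padded with -1; each instance's agenda has
            shape ``(agenda_size, 1)``.  The checklist cost encourages the decoder to produce
            these actions.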
        example_lisp_string : ``List[str]``
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        """
        batch_size = list(question.values())[0].size(0)
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        checklist_states = []
        all_terminal_productions = [
            set(instance_world.terminal_productions.values())
            for instance_world in world
        ]
        max_num_terminals = max(
            [len(terminals) for terminals in all_terminal_productions])
        for instance_actions, instance_agenda, terminal_productions in zip(
                actions, agenda_list, all_terminal_productions):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions,
                                                      terminal_productions,
                                                      max_num_terminals)
            checklist_target, terminal_actions, checklist_mask = checklist_info
            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            checklist_states.append(
                ChecklistState(terminal_actions=terminal_actions,
                               checklist_target=checklist_target,
                               checklist_mask=checklist_mask,
                               checklist=initial_checklist))
        initial_info = self._get_initial_state_and_scores(
            question=question,
            table=table,
            world=world,
            actions=actions,
            example_lisp_string=example_lisp_string,
            add_world_to_initial_state=True,
            checklist_states=checklist_states)
        initial_state = initial_info["initial_state"]
        # TODO(pradeep): Keep track of debug info. It's not straightforward currently because the
        # ERM's decode does not return the best states.
        outputs = self._decoder_trainer.decode(initial_state,
                                               self._decoder_step,
                                               self._get_state_cost)
        if not self.training:
            # TODO(pradeep): Can move most of this block to super class.
            linking_scores = initial_info["linking_scores"]
            feature_scores = initial_info["feature_scores"]
            similarity_scores = initial_info["similarity_scores"]
            batch_size = list(question.values())[0].size(0)
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs['action_mapping'] = action_mapping
            outputs['entities'] = []
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = similarity_scores
            outputs['logical_form'] = []
            best_action_sequences = outputs['best_action_sequences']
            outputs["best_action_sequence"] = []
            outputs['debug_info'] = []
            agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda]
            for i in range(batch_size):
                in_agenda_ratio = 0.0
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                outputs['logical_form'].append([])
                if i in best_action_sequences:
                    for j, action_sequence in enumerate(
                            best_action_sequences[i]):
                        action_strings = [
                            action_mapping[(i, action_index)]
                            for action_index in action_sequence
                        ]
                        try:
                            logical_form = world[i].get_logical_form(
                                action_strings, add_var_function=False)
                            outputs['logical_form'][-1].append(logical_form)
                        except ParsingError:
                            logical_form = "Error producing logical form"
                            # Keep the per-instance lists aligned even when parsing fails.
                            outputs['logical_form'][-1].append(logical_form)
                        if j == 0:
                            # Updating denotation accuracy and has_logical_form only based on the
                            # first logical form.
                            if logical_form.startswith("Error"):
                                self._has_logical_form(0.0)
                            else:
                                self._has_logical_form(1.0)
                            if example_lisp_string:
                                self._denotation_accuracy(
                                    logical_form, example_lisp_string[i])
                            outputs['best_action_sequence'].append(
                                action_strings)
                    outputs['entities'].append(world[i].table_graph.entities)
                    instance_possible_actions = actions[i]
                    agenda_actions = []
                    for rule_id in agenda_indices[i]:
                        rule_id = int(rule_id)
                        if rule_id == -1:
                            continue
                        action_string = instance_possible_actions[rule_id][0]
                        agenda_actions.append(action_string)
                    actions_in_agenda = [
                        action in action_strings for action in agenda_actions
                    ]
                    if actions_in_agenda:
                        # Note: This means that when there are no actions on agenda, agenda coverage
                        # will be 0, not 1.
                        in_agenda_ratio = sum(actions_in_agenda) / len(
                            actions_in_agenda)
                else:
                    outputs['best_action_sequence'].append([])
                    outputs['logical_form'][-1].append('')
                    self._has_logical_form(0.0)
                    if example_lisp_string:
                        self._denotation_accuracy(None, example_lisp_string[i])
                self._agenda_coverage(in_agenda_ratio)
        return outputs

    @staticmethod
    def _get_checklist_info(
        agenda: torch.LongTensor, all_actions: List[ProductionRuleArray],
        terminal_productions: Set[str], max_num_terminals: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
        world, and a length to pad the checklist vectors to, and returns a target checklist against
        which the checklist at each state will be compared to compute a loss, indices of
        ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions
        are relevant for checklist loss computation.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRuleArray]``
            All actions for one instance.
        ``terminal_productions`` : ``Set[str]``
            String representations of terminal productions in the corresponding world.
        ``max_num_terminals`` : ``int``
            Length to which the checklist vectors will be padded. This is the max number of
            terminal productions in all the worlds in the batch.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set(
            [int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRuleArray, a tuple where the first item is the production
            # rule string.
            if action[0] in terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        while len(target_checklist_list) < max_num_terminals:
            target_checklist_list.append([0])
            terminal_indices.append([-1])
        # (max_num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (max_num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list,
                                             dtype=torch.float)
        checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
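
On a toy instance, the bookkeeping above looks like this (a hedged sketch with hypothetical production strings, plain PyTorch):

import torch

all_actions = ['S -> A', 'A -> cell:usa', 'A -> fb:row.row.year', 'A -> count']
terminal_productions = {'A -> cell:usa', 'A -> count'}
agenda_indices_set = {1}      # we want 'A -> cell:usa' in the decoded sequence
max_num_terminals = 3         # max over the batch, so we pad by one entry

terminal_indices, target_checklist_list = [], []
for index, action in enumerate(all_actions):
    if action in terminal_productions:
        terminal_indices.append([index])
        target_checklist_list.append([1 if index in agenda_indices_set else 0])
while len(target_checklist_list) < max_num_terminals:
    target_checklist_list.append([0])
    terminal_indices.append([-1])   # padding entries point at no action

target_checklist = torch.tensor(target_checklist_list, dtype=torch.float)  # [[1.], [0.], [0.]]
checklist_mask = (target_checklist != 0).float()  # only agenda terminals count in the loss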

    def _get_state_cost(self, state: WikiTablesDecoderState) -> torch.Tensor:
        if not state.is_finished():
            raise RuntimeError(
                "_get_state_cost() is not defined for unfinished states!")

        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
        # penalizing agenda actions produced multiple times.
        checklist_balance = torch.clamp(state.checklist_state[0].get_balance(),
                                        min=0.0)
        checklist_cost = torch.sum((checklist_balance)**2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        denotation_cost = torch.sum(
            state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        action_history = state.action_history[0]
        batch_index = state.batch_indices[0]
        action_strings = [
            state.possible_actions[batch_index][i][0] for i in action_history
        ]
        logical_form = state.world[batch_index].get_logical_form(
            action_strings)
        lisp_string = state.example_lisp_string[batch_index]
        if self._denotation_accuracy.evaluate_logical_form(
                logical_form, lisp_string):
            cost = checklist_cost
        else:
            cost = checklist_cost + (
                1 - self._checklist_cost_weight) * denotation_cost
        return cost
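
Concretely, with ``checklist_cost_weight = 0.6``, a clamped checklist balance of ``[1, 0, 1]`` (two agenda actions never produced) and a checklist target summing to 3, a path with an incorrect denotation costs ``0.6 * 2 + 0.4 * 3 = 2.4``, while a path with the same balance but a correct denotation costs only ``0.6 * 2 = 1.2``.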

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        The base class returns a dict with dpd accuracy, denotation accuracy, and logical form
        percentage metrics. We add the agenda coverage metric here.
        """
        metrics = super().get_metrics(reset)
        metrics["agenda_coverage"] = self._agenda_coverage.get_metric(reset)
        return metrics

    @classmethod
    def from_params(cls, vocab,
                    params: Params) -> 'WikiTablesErmSemanticParser':
        question_embedder = TextFieldEmbedder.from_params(
            vocab, params.pop("question_embedder"))
        action_embedding_dim = params.pop_int("action_embedding_dim")
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        entity_encoder = Seq2VecEncoder.from_params(
            params.pop('entity_encoder'))
        mixture_feedforward_type = params.pop('mixture_feedforward', None)
        if mixture_feedforward_type is not None:
            mixture_feedforward = FeedForward.from_params(
                mixture_feedforward_type)
        else:
            mixture_feedforward = None
        input_attention = Attention.from_params(params.pop("attention"))
        decoder_beam_size = params.pop_int("decoder_beam_size")
        decoder_num_finished_states = params.pop_int(
            "decoder_num_finished_states", None)
        max_decoding_steps = params.pop_int("max_decoding_steps")
        normalize_beam_score_by_length = params.pop(
            "normalize_beam_score_by_length", False)
        use_neighbor_similarity_for_linking = params.pop_bool(
            "use_neighbor_similarity_for_linking", False)
        dropout = params.pop_float('dropout', 0.0)
        num_linking_features = params.pop_int('num_linking_features', 10)
        tables_directory = params.pop('tables_directory', '/wikitables/')
        rule_namespace = params.pop('rule_namespace', 'rule_labels')
        checklist_cost_weight = params.pop_float("checklist_cost_weight", 0.6)
        mml_model_file = params.pop('mml_model_file', None)
        params.assert_empty(cls.__name__)
        return cls(
            vocab,
            question_embedder=question_embedder,
            action_embedding_dim=action_embedding_dim,
            encoder=encoder,
            entity_encoder=entity_encoder,
            mixture_feedforward=mixture_feedforward,
            input_attention=input_attention,
            decoder_beam_size=decoder_beam_size,
            decoder_num_finished_states=decoder_num_finished_states,
            max_decoding_steps=max_decoding_steps,
            normalize_beam_score_by_length=normalize_beam_score_by_length,
            checklist_cost_weight=checklist_cost_weight,
            use_neighbor_similarity_for_linking=use_neighbor_similarity_for_linking,
            dropout=dropout,
            num_linking_features=num_linking_features,
            tables_directory=tables_directory,
            rule_namespace=rule_namespace,
            initial_mml_model_file=mml_model_file)
Ejemplo n.º 33
0
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    To instantiate this model with parameters matching those in the original paper, simply use
    ``BidirectionalAttentionFlow.from_params(vocab, Params({}))``.  This will construct all of the
    various dependencies needed for the constructor for you.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    attention_similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    initializer : ``InitializerApplicator``
        We will use this to initialize the parameters in the model, calling ``initializer(self)``.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    evaluation_json_file : ``str``, optional
        If given, we will load this JSON into memory and use it to compute official metrics
        against.  We need this separately from the validation dataset, because the official metrics
        use all of the annotations, while our dataset reader picks the most frequent one.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 evaluation_json_file: str = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))
        initializer(self)
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._official_em = Average()
        self._official_f1 = Average()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        if evaluation_json_file:
            logger.info("Prepping official evaluation dataset from %s", evaluation_json_file)
            with open(evaluation_json_file) as dataset_file:
                dataset_json = json.load(dataset_file)
            question_to_answers = {}
            for article in dataset_json['data']:
                for paragraph in article['paragraphs']:
                    for question in paragraph['qas']:
                        question_id = question['id']
                        answers = [answer['text'] for answer in question['answers']]
                        question_to_answers[question_id] = answers

            self._official_eval_dataset = question_to_answers
        else:
            self._official_eval_dataset = None

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.  The ending position is `exclusive`, so our
            :class:`~allennlp.data.dataset_readers.SquadReader` adds a special ending token to the
            end of the passage, to allow for the last token to be included in the answer span.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `exclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span end position (exclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self._get_best_span(span_start_logits, span_end_logits)

        output_dict = {"span_start_logits": span_start_logits,
                       "span_start_probs": span_start_probs,
                       "span_end_logits": span_end_logits,
                       "span_end_probs": span_end_probs,
                       "best_span": best_span}
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        if metadata is not None and self._official_eval_dataset:
            output_dict['best_span_str'] = []
            for i in range(batch_size):
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                best_span_string = self._compute_official_metrics(metadata[i], predicted_span)  # type: ignore
                output_dict['best_span_str'].append(best_span_string)
        return output_dict

    def _compute_official_metrics(self,
                                  metadata: Dict[str, Any],
                                  predicted_span: Tuple[int, int]) -> str:
        passage = metadata['original_passage']
        offsets = metadata['token_offsets']
        question_id = metadata.get('id', None)
        start_offset = offsets[predicted_span[0]][0]
        end_offset = offsets[predicted_span[1]][1]
        span_string = passage[start_offset:end_offset]
        if question_id in self._official_eval_dataset:
            ground_truth = self._official_eval_dataset[question_id]
            exact_match = squad_eval.metric_max_over_ground_truths(
                    squad_eval.exact_match_score,
                    span_string,
                    ground_truth)
            f1_score = squad_eval.metric_max_over_ground_truths(
                    squad_eval.f1_score,
                    span_string,
                    ground_truth)
            self._official_em(100 * exact_match)
            self._official_f1(100 * f1_score)
        return span_string

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': self._official_em.get_metric(reset),
                'f1': self._official_f1.get_metric(reset),
                }

    def predict_span(self, question: TextField, passage: TextField) -> Dict[str, Any]:
        """
        Given a question and a passage, predicts the span in the passage that answers the question.

        Parameters
        ----------
        question : ``TextField``
        passage : ``TextField``
            A ``TextField`` containing the tokens in the passage.  Note that we typically add
            ``SquadReader.STOP_TOKEN`` as the final token in the passage, because we use exclusive
            span indices.  Be sure you've added that to the passage you pass in here.

        Returns
        -------
        A Dict containing:

        span_start_probs : numpy.ndarray
        span_end_probs : numpy.ndarray
        best_span : (int, int)
        """
        instance = Instance({'question': question, 'passage': passage})
        instance.index_fields(self.vocab)
        model_input = util.arrays_to_variables(instance.as_array_dict(),
                                               add_batch_dimension=True,
                                               for_training=False)
        output_dict = self.forward(**model_input)

        # Here we're just removing the batch dimension and converting things to numpy arrays /
        # tuples instead of pytorch variables.
        return {
                "span_start_probs": output_dict["span_start_probs"].data.squeeze(0).cpu().numpy(),
                "span_end_probs": output_dict["span_end_probs"].data.squeeze(0).cpu().numpy(),
                "best_span": tuple(output_dict["best_span"].data.squeeze(0).cpu().numpy()),
                }

    @staticmethod
    def _get_best_span(span_start_logits: Variable, span_end_logits: Variable) -> Variable:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = Variable(span_start_logits.data.new()
                                  .resize_(batch_size, 2).fill_(0)).long()

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
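
A quick sanity check of the span search above, on toy logits (a hedged sketch; it assumes the class is importable and keeps the old-style ``Variable`` wrappers used in the surrounding code):

import torch
from torch.autograd import Variable

start_logits = Variable(torch.Tensor([[0.1, 2.0, 0.3, 0.0]]))
end_logits = Variable(torch.Tensor([[0.0, 0.2, 1.5, 0.4]]))
# The running argmax over starts guarantees the end index never precedes the
# chosen start; here the best-scoring admissible span is (1, 2).
print(BidirectionalAttentionFlow._get_best_span(start_logits, end_logits))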

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'BidirectionalAttentionFlow':
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
        num_highway_layers = params.pop("num_highway_layers")
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        similarity_function = SimilarityFunction.from_params(params.pop("similarity_function"))
        modeling_layer = Seq2SeqEncoder.from_params(params.pop("modeling_layer"))
        span_end_encoder = Seq2SeqEncoder.from_params(params.pop("span_end_encoder"))
        initializer = InitializerApplicator.from_params(params.pop("initializer", []))
        dropout = params.pop('dropout', 0.2)
        evaluation_json_file = params.pop('evaluation_json_file', None)
        mask_lstms = params.pop('mask_lstms', True)
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   num_highway_layers=num_highway_layers,
                   phrase_layer=phrase_layer,
                   attention_similarity_function=similarity_function,
                   modeling_layer=modeling_layer,
                   span_end_encoder=span_end_encoder,
                   initializer=initializer,
                   dropout=dropout,
                   mask_lstms=mask_lstms,
                   evaluation_json_file=evaluation_json_file)
class PTNChainBidirectionalAttentionFlow(Model):
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 gate_sent_encoder: Seq2SeqEncoder,
                 gate_self_attention_layer: Seq2SeqEncoder,
                 bert_projection: FeedForward,
                 span_gate: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 output_att_scores: bool = True,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:

        super(PTNChainBidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder

        self._dropout = torch.nn.Dropout(p=dropout)

        self._output_att_scores = output_att_scores

        self._span_gate = span_gate

        self._bert_projection = bert_projection

        # The sentence encoder and self-attention layer are accepted as constructor
        # arguments but intentionally disabled: ``None`` is passed to the span gate
        # in ``forward``.
        self._gate_sent_encoder = None
        self._gate_self_attention_layer = None

        self._f1_metrics = AttF1Measure(0.5, top_k=False)

        self._loss_trackers = {'loss': Average(),
                               'rl_loss': Average()}

        self.evd_sup_acc_metric = ChainAccuracy()
        self.evd_ans_metric = Average()
        self.evd_beam_ans_metric = Average()

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                transition_mask: torch.IntTensor = None,
                start_transition_mask: torch.Tensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # In this model, we only take the first chain in ``evd_chain_labels`` for supervision.
        evd_chain_labels = evd_chain_labels[:, 0] if evd_chain_labels is not None else None
        # There may be instances for which no evidence chain could be found for training;
        # in that case, use the mask to ignore those instances.
        evd_instance_mask = (evd_chain_labels[:, 0] != 0).float() if evd_chain_labels is not None else None
        #print('passage size:', passage['bert'].shape)
        # bert embedding for answer prediction
        # shape: [batch_size, max_q_len, emb_size]
        print('\nBert wordpiece size:', passage['bert'].shape)
        embedded_question = self._text_field_embedder(question)
        # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage)
        # print('\npassage size:', embedded_passage.shape)
        #embedded_question = self._bert_projection(embedded_question)
        #embedded_passage = self._bert_projection(embedded_passage)
        #print('size embedded_passage:', embedded_passage.shape)
        # mask
        ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()
        #print(context_mask.shape)
        # chain prediction
        # Shape(all_predictions): (batch_size, num_decoding_steps)
        # Shape(all_logprobs): (batch_size, num_decoding_steps)
        # Shape(seq_logprobs): (batch_size,)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(gate_probs): (batch_size * num_spans, 1)
        # Shape(gate_mask): (batch_size, num_spans)
        # Shape(g_att_score): (batch_size, num_heads, num_spans, num_spans)
        # Shape(orders): (batch_size, K, num_spans)
        (all_predictions, all_logprobs, seq_logprobs, gate, gate_probs,
         gate_mask, g_att_score, orders) = self._span_gate(
             embedded_passage, context_mask,
             embedded_question, ques_mask,
             evd_chain_labels,
             self._gate_self_attention_layer,
             self._gate_sent_encoder,
             transition_mask,
             start_transition_mask)
        batch_size, num_spans, max_batch_span_width = context_mask.size()

        output_dict = {
            "pred_sent_labels": gate.squeeze(1).view(batch_size, num_spans), #[B, num_span]
            "gate_probs": gate_probs.squeeze(1).view(batch_size, num_spans), #[B, num_span]
            "pred_sent_orders": orders, #[B, K, num_span]
        }
        if self._output_att_scores:
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # compute evd rl training metric, rewards, and loss
        print("sent label:")
        for b_label in np.array(sent_labels.cpu()):
            b_label = b_label == 1
            indices = np.arange(len(b_label))
            print(indices[b_label] + 1)
        evd_TP, evd_NP, evd_NT = self._f1_metrics(gate.squeeze(1).view(batch_size, num_spans),
                                                  sent_labels,
                                                  mask=gate_mask,
                                                  instance_mask=evd_instance_mask if self.training else None,
                                                  sum=False)
        # print("TP:", evd_TP)
        # print("NP:", evd_NP)
        # print("NT:", evd_NT)
        evd_ps = np.array(evd_TP) / (np.array(evd_NP) + 1e-13)
        evd_rs = np.array(evd_TP) / (np.array(evd_NT) + 1e-13)
        evd_f1s = 2. * ((evd_ps * evd_rs) / (evd_ps + evd_rs + 1e-13))
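        # For example, with evd_TP = [2.], evd_NP = [3.] predicted positives and
        # evd_NT = [4.] gold positives: precision = 2/3, recall = 2/4, and
        # F1 = 2 * p * r / (p + r) ~= 0.57; the 1e-13 terms guard against 0/0.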
        predict_mask = get_evd_prediction_mask(all_predictions.unsqueeze(1), eos_idx=0)[0]
        gold_mask = get_evd_prediction_mask(evd_chain_labels, eos_idx=0)[0]
        # default to take multiple predicted chains, so unsqueeze dim 1
        self.evd_sup_acc_metric(predictions=all_predictions.unsqueeze(1), gold_labels=evd_chain_labels,
                                predict_mask=predict_mask, gold_mask=gold_mask, instance_mask=evd_instance_mask)
        print("gold chain:", evd_chain_labels)
        predict_mask = predict_mask.float().squeeze(1)
        rl_loss = -torch.mean(torch.sum(all_logprobs * predict_mask * evd_instance_mask[:, None], dim=1))
        # torch.cuda.empty_cache()
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        # Compute before loss for rl
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            #token_spans_sp = []
            #token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            pred_chains_include_ans = []
            beam_pred_chains_include_ans = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_sent_tokens'])
                #token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                #offsets = metadata[i]['token_offsets']
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                # shift the sentence indices back by one (0 is reserved as the EOS index)
                evd_possible_chains.append([s_idx-1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0])
                ans_sent_idxs.append([s_idx-1 for s_idx in metadata[i]['ans_sent_idxs']])
                print("ans_sent_idxs:", metadata[i]['ans_sent_idxs'])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_orders = orders[i].detach().cpu().numpy()
                    if any([pred_sent_orders[0][s_idx-1] >= 0 for s_idx in metadata[i]['ans_sent_idxs']]):
                        self.evd_ans_metric(1)
                        pred_chains_include_ans.append(1)
                    else:
                        self.evd_ans_metric(0)
                        pred_chains_include_ans.append(0)
                    if any([any([pred_sent_orders[beam][s_idx-1] >= 0 for s_idx in metadata[i]['ans_sent_idxs']]) 
                                                                      for beam in range(len(pred_sent_orders))]):
                        self.evd_beam_ans_metric(1)
                        beam_pred_chains_include_ans.append(1)
                    else:
                        self.evd_beam_ans_metric(0)
                        beam_pred_chains_include_ans.append(0)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_sent_tokens'] = passage_tokens
            #output_dict['token_spans_sp'] = token_spans_sp
            #output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['pred_chains_include_ans'] = pred_chains_include_ans
            output_dict['beam_pred_chains_include_ans'] = beam_pred_chains_include_ans
            output_dict['_id'] = ids

        # Compute the loss for training.
        if evd_chain_labels is not None:
            try:
                loss = rl_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['rl_loss'](rl_loss)
                output_dict["loss"] = loss
            except RuntimeError:
                print('\n meta_data:', metadata)
                print(output_dict['_id'])

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        p, r, evidence_f1_score = self._f1_metrics.get_metric(reset)
        ans_in_evd = self.evd_ans_metric.get_metric(reset)
        beam_ans_in_evd = self.evd_beam_ans_metric.get_metric(reset)
        metrics = {
                'evd_p': p,
                'evd_r': r,
                'evd_f1': evidence_f1_score,
                'ans_in_evd': ans_in_evd,
                'beam_ans_in_evd': beam_ans_in_evd,
                }
        for name, tracker in self._loss_trackers.items():
            metrics[name] = tracker.get_metric(reset).item()
        evd_sup_acc = self.evd_sup_acc_metric.get_metric(reset)
        metrics['evd_sup_acc'] = evd_sup_acc
        return metrics

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
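The single pass above tracks the running argmax of the start logits and pairs it with every candidate end index, so start <= end holds by construction and no O(n^2) span enumeration is needed. Below is a minimal standalone sketch of the same algorithm, written for illustration only; the helper name and toy logits are not from the original code and assume nothing beyond torch.

import torch

def best_span(start_logits: torch.Tensor, end_logits: torch.Tensor) -> torch.Tensor:
    # Keep the best start index seen so far and pair it with each candidate
    # end index; start <= end is guaranteed by construction.
    batch_size, length = start_logits.size()
    best = start_logits.new_zeros((batch_size, 2), dtype=torch.long)
    for b in range(batch_size):
        best_start, best_score = 0, float("-inf")
        for j in range(length):
            if start_logits[b, j] > start_logits[b, best_start]:
                best_start = j
            score = (start_logits[b, best_start] + end_logits[b, j]).item()
            if score > best_score:
                best[b, 0], best[b, 1] = best_start, j
                best_score = score
    return best

# Toy check: start index 1 has the largest start logit, and end index 2
# maximizes start + end at or after it, so the best span is (1, 2).
print(best_span(torch.tensor([[0.1, 2.0, 0.3, 0.2]]),
                torch.tensor([[0.0, 0.5, 1.5, 0.4]])))  # tensor([[1, 2]])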
Example No. 35
0
class Nlvr2EndToEndModuleNetwork(Model):
    """
    A re-implementation of `End-to-End Module Networks for Visual Question Answering
    <https://www.semanticscholar.org/paper/Learning-to-Reason%3A-End-to-End-Module-Networks-for-Hu-Andreas/5e07d6951b7bc0c4113313a9586ce8178eacdf57>`_

    This implementation is based on our semantic parsing framework, and uses marginal likelihood to
    train the parser when labeled action sequences are not available.  It is `not` an exact
    re-implementation, but rather a very similar model with some significant differences in how the
    grammar is used.

    Parameters
    ----------
    vocab : ``Vocabulary``
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    freeze_encoder: ``bool``, optional (default=False)
        If true, weights of the encoder will be frozen during training.
    dropout : ``float``, optional (default=0)
        Dropout to be applied to encoder outputs and in modules
    tokens_namespace : ``str``, optional (default=tokens)
        The vocabulary namespace to use for tokens.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    denotation_namespace : ``str``, optional (default=labels)
        The vocabulary namespace to use for output labels.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    num_parse_only_batches : ``int``, optional (default=0)
        We will use this many training batches of `only` parse supervision, not denotation
        supervision.  This is helpful in cases where learning the correct programs and the
        program executor at the same time, both from scratch, is challenging.  This only
        works if you have labeled programs.
    use_gold_program_for_eval : ``bool``, optional (default=True)
        If true, we will use the gold program for evaluation when it is available (this only tests
        the program executor, not the parser).
    load_weights: ``str``, optional (default=None)
        Path from which to load model weights. If None or if path does not exist,
        no weights are loaded.
    use_modules: ``bool``, optional (default=True)
        If True, use modules and execute them according to programs. If False, use a
        feedforward network on top of the encoder to directly predict the label.
    positive_iou_threshold: ``float``, optional (default=0.5)
        Intersection-over-union (IOU) threshold to use for determining matches between
        ground-truth and predicted boxes in the faithfulness recall score.
    negative_iou_threshold: ``float``, optional (default=0.5)
        Intersection-over-union (IOU) threshold to use for determining matches between
        ground-truth and predicted boxes in the faithfulness precision score.
    nmn_settings: Dict, optional (default=None)
        A dictionary specifying choices determining architectures of the modules. This should
        not be None if use_modules == True.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        encoder: Seq2SeqEncoder,
        freeze_encoder: bool = False,
        dropout: float = 0.0,
        tokens_namespace: str = "tokens",
        rule_namespace: str = "rule_labels",
        denotation_namespace: str = "labels",
        num_parse_only_batches: int = 0,
        use_gold_program_for_eval: bool = True,
        load_weights: str = None,
        use_modules: bool = True,
        positive_iou_threshold: float = 0.5,
        negative_iou_threshold: float = 0.5,
        nmn_settings: Dict = None,
    ) -> None:
        super().__init__(vocab)
        self._encoder = encoder
        self._max_decoding_steps = 10
        self._add_action_bias = True
        self._dropout = torch.nn.Dropout(p=dropout)
        self._tokens_namespace = tokens_namespace
        self._rule_namespace = rule_namespace
        self._denotation_namespace = denotation_namespace
        self._num_parse_only_batches = num_parse_only_batches
        self._use_gold_program_for_eval = use_gold_program_for_eval
        self._nmn_settings = nmn_settings
        self._use_modules = use_modules
        self._training_batches_so_far = 0

        self._denotation_accuracy = CategoricalAccuracy()
        self._box_f1_score = ClassificationModuleScore(
            positive_iou_threshold=positive_iou_threshold,
            negative_iou_threshold=negative_iou_threshold,
        )
        self._best_box_f1_score = ClassificationModuleScore(
            positive_iou_threshold=positive_iou_threshold,
            negative_iou_threshold=negative_iou_threshold,
        )
        # TODO(mattg): use FullSequenceMatch instead of this.
        self._program_accuracy = Average()
        self._program_similarity = Average()
        self.loss = torch.nn.BCELoss()
        self.loss_with_logits = torch.nn.BCEWithLogitsLoss()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        action_embedding_dim = 100
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        if self._use_modules:
            self._language_parameters = VisualReasoningNlvr2Parameters(
                hidden_dim=self._encoder.get_output_dim(),
                initializer=self._encoder.encoder.model.init_bert_weights,
                max_boxes=self._nmn_settings["max_boxes"],
                dropout=dropout,
                nmn_settings=nmn_settings,
            )
        else:
            hid_dim = self._encoder.get_output_dim()
            self.logit_fc = torch.nn.Sequential(
                torch.nn.Linear(hid_dim * 2, hid_dim * 2),
                GeLU(),
                BertLayerNorm(hid_dim * 2, eps=1e-12),
                torch.nn.Linear(hid_dim * 2, 1),
            )
            self.logit_fc.apply(self._encoder.encoder.model.init_bert_weights)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        encoder_output_dim = self._encoder.get_output_dim()

        self._decoder_num_layers = 1

        self._beam_search = BeamSearch(beam_size=10)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder_output_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=encoder_output_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=AdditiveAttention(vector_dim=encoder_output_dim,
                                              matrix_dim=encoder_output_dim),
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers,
        )

        # Our language is constant across instances, so we just create one up front that we can
        # re-use to construct the `GrammarStatelet`.
        self._world = VisualReasoningNlvr2Language(None, None, None, None,
                                                   None, None)

        if load_weights is not None:
            if not os.path.exists(load_weights):
                print('Could not find weights path: ' + load_weights +
                      '. Continuing without loading weights.')
            else:
                if torch.cuda.is_available():
                    state = torch.load(load_weights)
                else:
                    state = torch.load(load_weights, map_location="cpu")
                encoder_prefix = "_encoder"
                lang_params_prefix = "_language_parameters"
                for key in list(state.keys()):
                    # Keep only encoder and language-parameter weights, and
                    # drop any "relate_layer" weights; `elif` avoids deleting
                    # the same key twice.
                    if (not key.startswith(encoder_prefix)
                            and not key.startswith(lang_params_prefix)):
                        del state[key]
                    elif "relate_layer" in key:
                        del state[key]
                self.load_state_dict(state, strict=False)

        if freeze_encoder:
            for param in self._encoder.parameters():
                param.requires_grad = False

        self.consistency_group_map = {}

    def consistency(self, reset: bool = False):
        if reset:
            self.consistency_group_map = {}
        if len(self.consistency_group_map) == 0:
            return 0.0
        consistency = len([
            group for group in self.consistency_group_map
            if self.consistency_group_map[group]
        ]) / float(len(self.consistency_group_map))
        return consistency
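    # A hedged illustration of the metric above: with consistency_group_map
    # equal to {"group-a-0": True, "group-b-1": False} (hypothetical keys),
    # consistency() returns 0.5, the fraction of example groups whose every
    # member seen so far was predicted correctly.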

    @overrides
    def forward(
        self,  # type: ignore
        sentence: Dict[str, torch.LongTensor],
        visual_feat: torch.Tensor,
        pos: torch.Tensor,
        image_id: List[str],
        gold_question_attentions: torch.Tensor = None,
        gold_box_annotations: List[List[List[List[float]]]] = None,
        identifier: List[str] = None,
        logical_form: List[str] = None,
        actions: List[List[ProductionRule]] = None,
        target_action_sequence: torch.LongTensor = None,
        valid_target_sequence: torch.Tensor = None,
        denotation: torch.Tensor = None,
        metadata: Dict = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        batch_size, img_num, obj_num, feat_size = visual_feat.size()
        assert img_num == 2 and feat_size == 2048
        text_masks = util.get_text_field_mask(sentence)
        (l1, v1, text, vis_only1), x1 = self._encoder(
            sentence[self._tokens_namespace], text_masks,
            visual_feat[:, 0], pos[:, 0])
        (l2, v2, text, vis_only2), x2 = self._encoder(
            sentence[self._tokens_namespace], text_masks,
            visual_feat[:, 1], pos[:, 1])
        l_orig = torch.cat((l1.unsqueeze(1), l2.unsqueeze(1)), dim=1)
        v_orig = torch.cat((v1.unsqueeze(1), v2.unsqueeze(1)), dim=1)
        x_orig = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
        vis_only = torch.cat((vis_only1.unsqueeze(1), vis_only2.unsqueeze(1)),
                             dim=1)

        # NOTE: Taking the lxmert output before cross modality layer (which is the same for both images)
        # Can also try concatenating (dim=-1) the two encodings
        encoded_sentence = text

        valid_target_sequence = valid_target_sequence.long()

        initial_state = self._get_initial_state(
            encoded_sentence[valid_target_sequence == 1],
            text_masks[valid_target_sequence == 1],
            actions,
        )
        initial_state.debug_info = [[] for _ in range(batch_size)]

        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_action_sequence = target_action_sequence[
                valid_target_sequence == 1]
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        losses = []
        if (self.training or self._use_gold_program_for_eval
            ) and target_action_sequence is not None:
            if valid_target_sequence.sum() > 0:
                outputs, final_states = self._decoder_trainer.decode(
                    initial_state,
                    self._transition_function,
                    (target_action_sequence.unsqueeze(1),
                     target_mask.unsqueeze(1)),
                )

                # B X TARGET X SENTENCE
                question_attention = [[
                    dbg["question_attention"]
                    for dbg in final_states[i][0].debug_info[0]
                ] for i in range(len(final_states))
                                      if valid_target_sequence[i].item() == 1]
                target_attn_loss = self._compute_target_attn_loss(
                    question_attention,
                    gold_question_attentions[valid_target_sequence ==
                                             1].squeeze(-1),
                )
                if not self._use_gold_program_for_eval:
                    outputs["loss"] += target_attn_loss
                else:
                    outputs["loss"] = torch.tensor(0.0)
                    if torch.cuda.is_available():
                        outputs["loss"] = outputs["loss"].cuda()
            else:
                final_states = None
                outputs["loss"] = torch.tensor(0.0)
                # Guard the .cuda() call so the CPU-only path does not crash.
                if torch.cuda.is_available():
                    outputs["loss"] = outputs["loss"].cuda()
            if (1 - valid_target_sequence).sum() > 0:
                if final_states is None:
                    final_states = {}
                initial_state = self._get_initial_state(
                    encoded_sentence[valid_target_sequence == 0],
                    text_masks[valid_target_sequence == 0],
                    actions,
                )
                remaining_states = self._beam_search.search(
                    self._max_decoding_steps,
                    initial_state,
                    self._transition_function,
                    keep_final_unfinished_states=False,
                )
                new_final_states = {}
                count = 0
                for i in range(valid_target_sequence.shape[0]):
                    if valid_target_sequence[i] < 0.5:
                        new_final_states[i] = remaining_states[i - count]
                    else:
                        new_final_states[i] = final_states[count]
                        count += 1
                final_states = new_final_states
            if final_states is None:
                len_final_states = 0
            else:
                len_final_states = len(final_states)
        else:
            initial_state = self._get_initial_state(encoded_sentence,
                                                    text_masks, actions)
            final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False,
            )

        action_mapping = {}
        for action_index, action in enumerate(actions[0]):
            action_mapping[action_index] = action[0]

        outputs["action_mapping"] = action_mapping
        outputs["debug_info"] = []
        outputs["modules_debug_info"] = []
        outputs["best_action_sequence"] = []
        outputs["image_id"] = []
        outputs["prediction"] = []
        outputs["label"] = []
        outputs["correct"] = []
        outputs["bboxes"] = []

        outputs = self._compute_parsing_validation_outputs(
            actions,
            target_action_sequence.shape[0],
            final_states,
            initial_state,
            [
                datum for i, datum in enumerate(metadata)
                if valid_target_sequence[i].item() == 1
            ],
            outputs,
            target_action_sequence,
        )

        if not self._use_modules:
            # `x` is only built on the module path below, so use the stacked
            # cross-modality output computed above.
            logits = self.logit_fc(x_orig.view(batch_size, -1))
            if denotation is not None:
                self._denotation_accuracy(
                    torch.cat((torch.zeros_like(logits), logits), dim=-1),
                    denotation)
                if self.training:
                    outputs["loss"] += self.loss_with_logits(
                        logits.view(-1),
                        denotation.view(-1).float())
                    self._training_batches_so_far += 1
                else:
                    outputs["loss"] = self.loss_with_logits(
                        logits.view(-1),
                        denotation.view(-1).float())
            return outputs

        if self._nmn_settings["mask_non_attention"]:
            zero_one_mult = (torch.zeros_like(text_masks).unsqueeze(1).repeat(
                1, target_action_sequence.shape[1], 1))
            reformatted_gold_question_attentions = torch.where(
                gold_question_attentions.squeeze(-1) == -1,
                torch.zeros_like(gold_question_attentions.squeeze(-1)),
                gold_question_attentions.squeeze(-1),
            )
            pred_question_attention = [
                torch.stack([
                    torch.nn.functional.pad(
                        dbg["question_attention"].view(-1),
                        pad=(
                            0,
                            zero_one_mult.shape[-1] -
                            dbg["question_attention"].numel(),
                        ),
                    ) for dbg in final_states[i][0].debug_info[0]
                ]) for i in range(len(final_states))
            ]
            pred_question_attention = torch.stack([
                torch.nn.functional.pad(
                    attention,
                    pad=(0, 0, 0, zero_one_mult.shape[1] - attention.shape[0]),
                ) for attention in pred_question_attention
            ]).to(zero_one_mult.device)
            zero_one_mult.scatter_(
                2,
                reformatted_gold_question_attentions,
                torch.ones_like(reformatted_gold_question_attentions),
            )
            zero_one_mult[:, :, 0] = 1.0
            sep_indices = (
                (text_masks *
                 (1 + torch.arange(text_masks.shape[1]).unsqueeze(0).repeat(
                     batch_size, 1).to(text_masks.device))).argmax(1).long())
            sep_indices = (sep_indices.unsqueeze(1).repeat(
                1, text_masks.shape[1]).unsqueeze(1).repeat(
                    1, target_action_sequence.shape[1], 1))
            indices_dim2 = (torch.arange(
                text_masks.shape[1]).unsqueeze(0).unsqueeze(0).repeat(
                    batch_size, target_action_sequence.shape[1],
                    1).to(sep_indices.device).long())
            zero_one_mult = torch.where(
                sep_indices == indices_dim2,
                torch.ones_like(zero_one_mult),
                zero_one_mult,
            ).float()
            reshaped_questions = (
                sentence[self._tokens_namespace].unsqueeze(1).repeat(
                    1, target_action_sequence.shape[1],
                    1).view(-1, text_masks.shape[-1]))
            reshaped_visual_feat = (visual_feat.unsqueeze(1).repeat(
                1, target_action_sequence.shape[1], 1, 1,
                1).view(-1, img_num, obj_num, visual_feat.shape[-1]))
            reshaped_pos = (pos.unsqueeze(1).repeat(
                1, target_action_sequence.shape[1], 1, 1,
                1).view(-1, img_num, obj_num, pos.shape[-1]))
            zero_one_mult = zero_one_mult.view(-1, text_masks.shape[-1])
            q_att_filter = zero_one_mult.sum(1) > 2
            (l1, v1, text, vis_only1), x1 = self._encoder(
                reshaped_questions[q_att_filter, :],
                zero_one_mult[q_att_filter, :],
                reshaped_visual_feat[q_att_filter, 0, :, :],
                reshaped_pos[q_att_filter, 0, :, :],
            )
            (l2, v2, text, vis_only2), x2 = self._encoder(
                reshaped_questions[q_att_filter, :],
                zero_one_mult[q_att_filter, :],
                reshaped_visual_feat[q_att_filter, 1, :, :],
                reshaped_pos[q_att_filter, 1, :, :],
            )
            l_cat = torch.cat((l1.unsqueeze(1), l2.unsqueeze(1)), dim=1)
            v_cat = torch.cat((v1.unsqueeze(1), v2.unsqueeze(1)), dim=1)
            x_cat = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
            l = [{} for _ in range(batch_size)]
            v = [{} for _ in range(batch_size)]
            x = [{} for _ in range(batch_size)]
            count = 0
            batch_index = -1
            for i in range(zero_one_mult.shape[0]):
                module_num = i % target_action_sequence.shape[1]
                if module_num == 0:
                    batch_index += 1
                    state = final_states[batch_index][0]
                    action_indices = state.action_history[0]
                    action_strings = [
                        action_mapping[action_index]
                        for action_index in action_indices
                    ]
                if q_att_filter[i].item():
                    l[batch_index][module_num] = self._dropout(l_cat[count])
                    v[batch_index][module_num] = self._dropout(v_cat[count])
                    x[batch_index][module_num] = self._dropout(x_cat[count])
                    count += 1
        else:
            l = self._dropout(l_orig)
            v = self._dropout(v_orig)
            x = self._dropout(x_orig)

        outputs["box_acc"] = [{} for _ in range(batch_size)]
        outputs["best_box_acc"] = [{} for _ in range(batch_size)]
        outputs["box_score"] = [{} for _ in range(batch_size)]
        outputs["box_f1"] = [{} for _ in range(batch_size)]
        outputs["box_f1_overall_score"] = [{} for _ in range(batch_size)]
        outputs["best_box_f1"] = [{} for _ in range(batch_size)]
        outputs["gold_box"] = []
        outputs["ious"] = [{} for _ in range(batch_size)]

        for batch_index in range(batch_size):
            if (self.training and self._training_batches_so_far <
                    self._num_parse_only_batches):
                continue
            if not final_states[batch_index]:
                logger.error(f"No pogram found for batch index {batch_index}")
                outputs["best_action_sequence"].append([])
                outputs["debug_info"].append([])
                continue

            outputs["modules_debug_info"].append([])
            denotation_log_prob_list = []
            # TODO(mattg): maybe we want to limit the number of states we evaluate (programs we
            # execute) at test time, just for efficiency.
            for state_index, state in enumerate(final_states[batch_index]):
                world = VisualReasoningNlvr2Language(
                    l[batch_index],
                    v[batch_index],
                    x[batch_index],
                    self._language_parameters,
                    metadata[batch_index]["tokenized_utterance"],
                    pos[batch_index],
                    self._nmn_settings,
                )
                action_indices = state.action_history[0]
                action_strings = [
                    action_mapping[action_index]
                    for action_index in action_indices
                ]
                # Shape: (num_denotations,)
                assert len(action_strings) == len(state.debug_info[0])
                # Plug in gold question attentions
                for i in range(len(state.debug_info[0])):
                    if (self._use_gold_program_for_eval
                            and valid_target_sequence[batch_index] == 1):
                        n_att_words = ((gold_question_attentions[batch_index,
                                                                 i] >=
                                        0).float().sum())
                        state.debug_info[0][i][
                            "question_attention"] = torch.zeros_like(
                                state.debug_info[0][i]["question_attention"])
                        if n_att_words > 0:
                            for j in gold_question_attentions[batch_index, i]:
                                if j >= 0:
                                    state.debug_info[0][i][
                                        "question_attention"][j] = (
                                            1.0 / n_att_words)
                    if (i not in l[batch_index]
                            and self._nmn_settings["mask_non_attention"]
                            and (action_strings[i][-4:] == "find"
                                 or action_strings[i][-6:] == "filter"
                                 or action_strings[i][-13:] == "with_relation"
                                 or action_strings[i][-7:] == "project")):
                        l[batch_index][i] = l_orig[batch_index, :, :]
                        v[batch_index][i] = v_orig[batch_index, :, :]
                        x[batch_index][i] = x_orig[batch_index, :]
                        world = VisualReasoningNlvr2Language(
                            l[batch_index],
                            v[batch_index],
                            x[batch_index],
                            self._language_parameters,
                            metadata[batch_index]["tokenized_utterance"],
                            pos[batch_index],
                            self._nmn_settings,
                        )
                world.parameters.train(self.training)

                state_denotation_probs = world.execute_action_sequence(
                    action_strings, state.debug_info[0])
                outputs["modules_debug_info"][batch_index].append(
                    world.modules_debug_info)

                # P(denotation | parse) * P(parse | question)
                world_log_prob = (state_denotation_probs + 1e-6).log()
                if not self._use_gold_program_for_eval:
                    world_log_prob += state.score[0]
                denotation_log_prob_list.append(world_log_prob)

            # P(denotation | parse) * P(parse | question) for the all programs on the beam.
            # Shape: (beam_size, num_denotations)
            denotation_log_probs = torch.stack(denotation_log_prob_list)
            # \Sum_parse P(denotation | parse) * P(parse | question) = P(denotation | question)
            # Shape: (num_denotations,)
            marginalized_denotation_log_probs = util.logsumexp(
                denotation_log_probs, dim=0)
            if denotation is not None:
                # The clamp is needed; without it the probabilities can fall
                # slightly outside [0, 1] (worth investigating why).
                state_denotation_probs = state_denotation_probs.clamp(min=0, max=1)

                loss = self.loss(
                    state_denotation_probs.unsqueeze(0),
                    denotation[batch_index].unsqueeze(0).float(),
                ).view(1)
                losses.append(loss)
                self._denotation_accuracy(
                    torch.tensor(
                        [1 - state_denotation_probs,
                         state_denotation_probs]).to(denotation.device),
                    denotation[batch_index],
                )
                group_id = metadata[batch_index]["identifier"].split("-")
                group_id = group_id[0] + "-" + group_id[1] + "-" + group_id[-1]
                if group_id not in self.consistency_group_map:
                    self.consistency_group_map[group_id] = True
                if (state_denotation_probs.item() >= 0.5
                        and denotation[batch_index].item() < 0.5) or (
                            state_denotation_probs.item() < 0.5
                            and denotation[batch_index].item() > 0.5):
                    self.consistency_group_map[group_id] = False
                if (gold_box_annotations is not None
                        and len(gold_box_annotations[batch_index]) > 0):
                    box_f1_score, overall_f1_score_value = self._box_f1_score(
                        outputs["modules_debug_info"][batch_index][0],
                        gold_box_annotations[batch_index],
                        pos[batch_index],
                    )
                    outputs["box_f1"][batch_index] = box_f1_score
                    outputs["box_f1_overall_score"][
                        batch_index] = overall_f1_score_value
                    best_f1_predictions = self._best_box_f1_score.compute_best_box_predictions(
                        outputs["modules_debug_info"][batch_index][0],
                        gold_box_annotations[batch_index],
                        pos[batch_index],
                    )
                    best_f1, _ = self._best_box_f1_score(
                        best_f1_predictions,
                        gold_box_annotations[batch_index],
                        pos[batch_index],
                    )
                    outputs["best_box_f1"][batch_index] = best_f1
                    outputs["gold_box"].append(
                        gold_box_annotations[batch_index])
            outputs["image_id"].append(image_id[batch_index])
            outputs["prediction"].append(world_log_prob.exp())
            if denotation is not None:
                outputs["label"].append(denotation[batch_index])
                outputs["correct"].append(world_log_prob.exp().round().int() ==
                                          denotation[batch_index].int())
            outputs["bboxes"].append(pos[batch_index])
        if losses:
            outputs["loss"] += torch.stack(losses).mean()
        if self.training:
            self._training_batches_so_far += 1
        return outputs

    def _compute_parsing_validation_outputs(
        self,
        actions,
        batch_size,
        final_states,
        initial_state,
        metadata,
        outputs,
        target_action_sequence,
    ):
        if (not self.training and target_action_sequence is not None
                and target_action_sequence.numel() > 0):
            outputs["parse_correct"] = []

            # skip beam search if we already searched
            # if self._use_gold_program_for_eval:
            #     final_states = self._beam_search.search(self._max_decoding_steps,
            #                                             initial_state,
            #                                             self._transition_function,
            #                                             keep_final_unfinished_states=False)
            best_action_sequences: Dict[int, List[List[int]]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in final_states:
                    best_action_indices = final_states[i][0].action_history[0]
                    best_action_sequences[i] = best_action_indices

                    targets = target_action_sequence[i].data
                    sequence_in_targets = self._action_history_match(
                        best_action_indices, targets)
                    self._program_accuracy(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(
                        None, best_action_indices, targets)
                    self._program_similarity(similarity.ratio())
                    outputs["parse_correct"].append(sequence_in_targets)
                else:
                    self._program_accuracy(0)
                    self._program_similarity(0)
                    continue

            batch_action_strings = self._get_action_strings(
                actions, best_action_sequences)
            if metadata is not None:
                outputs["sentence_tokens"] = [
                    x["tokenized_utterance"] for x in metadata
                ]
                outputs["utterance"] = [x["utterance"] for x in metadata]
                outputs["parse_gold"] = [x["gold"] for x in metadata]
            outputs["debug_info"] = []
            outputs["parse_predicted"] = []
            outputs["action_mapping"] = []
            for i in range(batch_size):
                if i in final_states:
                    outputs["debug_info"].append(
                        final_states[i][0].debug_info[0])  # type: ignore
                    outputs["action_mapping"].append(
                        [a[0] for a in actions[i]])
                    outputs["parse_predicted"].append(
                        self._world.action_sequence_to_logical_form(
                            batch_action_strings[i]))
            outputs["best_action_strings"] = batch_action_strings
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]

        return outputs

    def _get_initial_state(
        self,
        encoder_outputs: torch.Tensor,
        utterance_mask: torch.Tensor,
        actions: List[ProductionRule],
    ) -> GrammarBasedState:
        batch_size = encoder_outputs.size(0)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        # Use CLS states as final encoder outputs
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                encoder_outputs.shape[-1])
        initial_score = encoder_outputs.data.new_zeros(batch_size)
        attended_sentence, _ = self._transition_function.attend_on_question(
            final_encoder_output, encoder_outputs, utterance_mask)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                encoder_output = final_encoder_output[i].repeat(
                    self._decoder_num_layers, 1)
                cell = memory_cell[i].repeat(self._decoder_num_layers, 1)
            else:
                encoder_output = final_encoder_output[i]
                cell = memory_cell[i]
            initial_rnn_state.append(
                RnnStatelet(
                    encoder_output,
                    cell,
                    self._first_action_embedding,
                    attended_sentence[i],
                    encoder_output_list,
                    utterance_mask_list,
                ))

        initial_grammar_state = [
            self._create_grammar_state(actions[i]) for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=[[] for _ in range(batch_size)],
        )
        return initial_state

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence exactly matches the leading slice of the targets.
        return int(predicted_tensor.equal(targets_trimmed))
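    # Hedged examples of the match semantics above: predicted [3, 1, 4]
    # against targets tensor([3, 1, 4, 0, 0]) returns 1, because the
    # prediction equals the leading slice of the targets; predicted [1, 4]
    # against the same targets returns 0, since the comparison is
    # position-wise from index 0, not a substring search.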

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = {
            "denotation_acc": self._denotation_accuracy.get_metric(reset),
            "program_acc": self._program_accuracy.get_metric(reset),
            "consistency": self.consistency(reset),
        }
        for m in self._box_f1_score.modules:
            metrics["_" + m + "_box_f1"] = self._box_f1_score.get_metric(
                reset=False, module=m)["f1"]
            metrics["_" + m +
                    "_best_box_f1"] = self._best_box_f1_score.get_metric(
                        reset=False, module=m)["f1"]
        box_f1_pr = self._box_f1_score.get_metric(reset=reset)
        for key in box_f1_pr:
            metrics["overall_box_" + key] = box_f1_pr[key]
        best_box_f1_pr = self._best_box_f1_score.get_metric(reset=reset)
        for key in best_box_f1_pr:
            metrics["overall_best_box_" + key] = best_box_f1_pr[key]
        return metrics

    def _create_grammar_state(
            self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input ``ProductionRules``,
        and we use those to embed the valid actions for every non-terminal type.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = self._world.get_nonterminal_productions()

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                global_actions.append((production_rule_array[2], action_index))

            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors,
                                             dim=0).long()
            global_input_embeddings = self._action_embedder(
                global_action_tensor)
            global_output_embeddings = self._output_action_embedder(
                global_action_tensor)
            translated_valid_actions[key]["global"] = (
                global_input_embeddings,
                global_output_embeddings,
                list(global_action_ids),
            )

        return GrammarStatelet([START_SYMBOL], translated_valid_actions,
                               self._world.is_nonterminal)

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        # TODO(mattg): FIX THIS - I haven't touched this method yet.
        action_mapping = output_dict["action_mapping"]
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict["debug_info"]
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(
                    predicted_actions, debug_info):
                action_info = {}
                action_info["predicted_action"] = predicted_action
                considered_actions = action_debug_info["considered_actions"]
                probabilities = action_debug_info["probabilities"]
                actions = []
                for action, probability in zip(considered_actions,
                                               probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)],
                                        probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info["considered_actions"] = considered_actions
                action_info["action_probabilities"] = probabilities
                action_info["utterance_attention"] = action_debug_info.get(
                    "question_attention", [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict

    def _compute_target_attn_loss(self, question_attention,
                                  gold_question_attention):
        attn_loss = 0
        normalizer = 0

        gold_question_attention = gold_question_attention.cpu().numpy()

        # TODO: Pad and batch this for performance
        for instance_attn, gld_instance_attn in zip(question_attention,
                                                    gold_question_attention):
            for step_attn, gld_step_attn in zip(instance_attn,
                                                gld_instance_attn):
                if gld_step_attn[0] == -1:
                    continue
                # consider only non-padding indices
                gld_step_attn = [a for a in gld_step_attn if a > -1]
                given_attn = step_attn[gld_step_attn]
                attn_loss += (given_attn.sum() + 1e-8).log()
                normalizer += 1

        if normalizer == 0:
            return 0
        else:
            return -1 * (attn_loss / normalizer)
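    # A hedged reading of the loss above: for each decoding step that has gold
    # attention, sum the predicted attention mass placed on the gold token
    # positions, take its log (floored by 1e-8), and average over steps;
    # negating turns "maximize attention mass on gold tokens" into a loss.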

    @classmethod
    def _get_action_strings(
        cls,
        possible_actions: List[List[ProductionRule]],
        action_indices: Dict[int, List[List[int]]],
    ) -> List[List[List[str]]]:
        """
        Takes a list of possible actions and indices of decoded actions into those possible actions
        for a batch and returns sequences of action strings. We assume ``action_indices`` is a dict
        mapping batch indices to k-best decoded sequence lists.
        """
        all_action_strings: List[List[List[str]]] = []
        batch_size = len(possible_actions)
        for i in range(batch_size):
            batch_actions = possible_actions[i]
            batch_best_sequences = action_indices[
                i] if i in action_indices else []
            # This will append an empty list to ``all_action_strings`` if ``batch_best_sequences``
            # is empty.
            action_strings = [
                batch_actions[rule_id][0] for rule_id in batch_best_sequences
            ]
            all_action_strings.append(action_strings)
        return all_action_strings
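In the forward pass above, P(denotation | question) is obtained by marginalizing P(denotation | parse) * P(parse | question) over the programs on the beam with `util.logsumexp`. A minimal standalone sketch of that marginalization follows; the beam size and probabilities are made up for illustration and assume nothing beyond torch.

import torch

# Hypothetical beam of 3 parses:
# log P(denotation | parse) + log P(parse | question), shape (beam_size,).
denotation_log_probs = (torch.tensor([0.9, 0.6, 0.2]).log()
                        + torch.tensor([0.5, 0.3, 0.2]).log())

# Sum over parses: P(denotation | question).
marginal = torch.logsumexp(denotation_log_probs, dim=0).exp()
print(marginal)  # ~0.67 = 0.9 * 0.5 + 0.6 * 0.3 + 0.2 * 0.2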
Example No. 36
0
class DialogQA(Model):
    """
    This class implements a modified version of the BiDAF model
    (with self attention and a residual layer, from the Clark and Gardner ACL 2017 paper), as used in
    the Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog: a list of question-answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    max_span_length: ``int``, optional (default=30)
        Maximum token length of the output span.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()
        max_turn_length = 12

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim,
                                                       'x,y,x*y')
        self._merge_atten = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(
                max_turn_length, marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding(
                (num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim,
                                                     self._encoding_dim,
                                                     'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)
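    # Hedged note on the marker embedding above: with num_context_answers
    # previous answers, each previous answer plausibly contributes four marker
    # labels (e.g. <1_start>, <1_in>, <1_end>, and a single-token marker such
    # as <1>), plus one shared 'O' label, giving (num_context_answers * 4) + 1
    # embeddings in total.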

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens are assigned 'O', except the passage tokens that belong to the previous
            answer in the dialog, which are assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, see dataset_readers/util/make_reading_comprehension_instance_quac.
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two turns back in the passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three turns back in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries.
        Each entry is a nested list: the outer list iterates over the dialogs in the batch,
        the inner one over the questions in each dialog.

        qid : List[List[str]]
            A list of lists of question ids.
        followup : List[List[int]]
            A list of lists of continuation marker prediction indices
            (y: yes, m: maybe follow up, n: don't follow up).
        yesno : List[List[int]]
            A list of lists of affirmation marker prediction indices
            (y: yes, x: not a yes/no question, n: no).
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question[
            'token_characters'].size()
        total_qa_count = batch_size * max_qa_count
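        # Dialogs are padded up to max_qa_count QA pairs; padded pairs carry label -1
        # (the ignore_index used by the losses below), so followup_list >= 0 marks real pairs.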
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question,
                                                      num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
            self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(
            self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question,
                                                 num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(
                question_num_ind)
            embedded_question = torch.cat(
                [embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
                                                     passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat(
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat(
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                        p3_answer_marker)
                    repeated_embedded_passage = torch.cat(
                        [repeated_embedded_passage, p3_answer_marker_emb],
                        dim=-1)

            repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage,
                                   repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            repeated_encoded_passage, passage_question_vectors,
            repeated_encoded_passage * passage_question_vectors,
            repeated_encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(
            self._residual_encoder(final_merged_passage,
                                   repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer,
                                                     residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
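        # Zero out the diagonal so a token never attends to itself: torch.eye
        # picks out the diagonal, and multiplying by (1 - eye) removes it.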
        self_mask = torch.eye(passage_length,
                              passage_length,
                              device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs,
                                           residual_layer)
        self_attention_vecs = torch.cat([
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        ],
                                        dim=-1)
        residual_layer = F.relu(
            self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage,
                                             repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(
            torch.cat([final_merged_passage, start_rep], dim=-1),
            repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(
            -1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        # Shape: (batch_size * max_qa_count, passage_length)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits,
                                                       span_end_logits,
                                                       span_yesno_logits,
                                                       span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
                                                    repeated_passage_mask),
                            span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             span_end.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # Select the yesno/followup logits at the gold span end to compute their losses.
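            # span_yesno_logits has shape (total_qa_count, passage_length, 3); after
            # view(-1), the three class logits for end token j of QA pair i sit at flat
            # indices i * passage_length * 3 + j * 3 + {0, 1, 2}, which is what the
            # arithmetic below constructs.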
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1),
                             followup_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)
            self._span_followup_accuracy(_followup,
                                         followup_list.view(-1),
                                         mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute max-F1 over each leave-one-out set of N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad_eval.metric_max_over_ground_truths(
                                    squad_eval.f1_score, best_span_string,
                                    refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
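        # Map predicted yesno/followup indices back to their label strings via the
        # ``yesno_labels`` and ``followup_labels`` vocabulary namespaces.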
        yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list] \
                      for yn_list in output_dict.pop("yesno")]
        followup_tags = [[self.vocab.get_token_from_index(x, namespace="followup_labels") for x in followup_list] \
                         for followup_list in output_dict.pop("followup")]
        output_dict['yesno'] = yesno_tags
        output_dict['followup'] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'yesno': self._span_yesno_accuracy.get_metric(reset),
            'followup': self._span_followup_accuracy.get_metric(reset),
            'f1': self._official_f1.get_metric(reset),
        }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        # Returns the indices of the highest-scoring span no longer than max_span_length
        # tokens, along with the yesno and followup predictions read off at the span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
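
The greedy search in ``_get_best_span_yesno_followup`` runs in a single O(passage_length) pass per example: it keeps a running argmax over the start logits and scores every token as a candidate end. A minimal NumPy-only sketch of the same idea for one unbatched example (``best_span_greedy`` is a hypothetical helper, not part of the model; like the original loop, it is only approximate once the length constraint binds, since the running start argmax may fall outside the allowed window):

import numpy as np

def best_span_greedy(start_logits, end_logits, max_span_length):
    # Approximately maximise start_logits[s] + end_logits[e]
    # subject to s <= e and e - s <= max_span_length, in one pass.
    best_score, best, start_argmax = -1e20, (0, 0), 0
    for j in range(len(end_logits)):
        # Track the best start seen so far (at or to the left of j).
        if start_logits[j] > start_logits[start_argmax]:
            start_argmax = j
        score = start_logits[start_argmax] + end_logits[j]
        if score > best_score and j - start_argmax <= max_span_length:
            best_score, best = score, (start_argmax, j)
    return best

print(best_span_greedy(np.array([0.1, 2.0, 0.3, 0.2]),
                       np.array([0.0, 0.1, 0.5, 1.5]),
                       max_span_length=2))  # -> (1, 3)
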
Ejemplo n.º 37
0
class DialogQA(Model):
    """
    This class implements a modified version of the BiDAF model
    (with self-attention and a residual layer, from the Clark and Gardner ACL 2017 paper), as used in
    the Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog: a list of question-answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    max_span_length: ``int``, optional (default=30)
        Maximum token length of the output span.
    max_turn_length: ``int``, optional (default=12)
        Maximum length of an interaction.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30,
                 max_turn_length: int = 12) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(max_turn_length,
                                                           marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,
                                                                     self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(phrase_layer.get_input_dim(),
                               text_field_embedder.get_output_dim() +
                               marker_embedding_dim * num_context_answers,
                               "phrase layer input dim",
                               "embedding dim + marker dim * num context answers")

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens are assigned 'O', except the passage tokens that belong to the previous
            answer in the dialog, which are assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, see dataset_readers/util/make_reading_comprehension_instance_quac.
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two turns back in the passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three turns back in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries.
        Each entry is a nested list: the outer list iterates over the dialogs in the batch,
        the inner one over the questions in each dialog.

        qid : List[List[str]]
            A list of lists of question ids.
        followup : List[List[int]]
            A list of lists of continuation marker prediction indices
            (y: yes, m: maybe follow up, n: don't follow up).
        yesno : List[List[int]]
            A list of lists of affirmation marker prediction indices
            (y: yes, x: not a yes/no question, n: no).
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # Shape: (batch_size * max_qa_count, passage_length)
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # Select the yesno/followup logits at the gold span end to compute their losses.
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute max-F1 over each leave-one-out set of N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list] \
                      for yn_list in output_dict.pop("yesno")]
        followup_tags = [[self.vocab.get_token_from_index(x, namespace="followup_labels") for x in followup_list] \
                         for followup_list in output_dict.pop("followup")]
        output_dict['yesno'] = yesno_tags
        output_dict['followup'] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'yesno': self._span_yesno_accuracy.get_metric(reset),
                'followup': self._span_followup_accuracy.get_metric(reset),
                'f1': self._official_f1.get_metric(reset), }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        # Returns the indices of the highest-scoring span no longer than max_span_length
        # tokens, along with the yesno and followup predictions read off at the span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4), dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
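
For questions with several gold answers, the evaluation loop above follows the QuAC convention: hold out each human reference in turn, score the prediction against the remaining N-1 references with max-over-ground-truths F1, and average the per-held-out scores. A self-contained sketch, using a simple whitespace-token F1 as a stand-in for ``squad_eval.f1_score`` (the official metric additionally lowercases and strips punctuation and articles):

from collections import Counter

def token_f1(prediction, reference):
    # Bag-of-tokens F1; a simplified stand-in for squad_eval.f1_score.
    pred, ref = prediction.split(), reference.split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

def leave_one_out_f1(prediction, references):
    # Average, over each held-out reference, of the max F1 against the rest.
    if len(references) == 1:
        return token_f1(prediction, references[0])
    scores = []
    for held_out in range(len(references)):
        rest = [r for i, r in enumerate(references) if i != held_out]
        scores.append(max(token_f1(prediction, r) for r in rest))
    return sum(scores) / len(scores)

print(leave_one_out_f1("in the park",
                       ["in the park", "at the park", "park"]))  # ~0.889
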
class SpansText2SqlParser(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    utterance_embedder : ``TextFieldEmbedder``
        Embedder for utterances.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    input_attention: ``Attention``
        We compute an attention over the input utterance at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    span_extractor: ``SpanExtractor``, optional
        If provided, extracts span representations from the encoded inputs.
        These span representations are used for decoding.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 mydatabase: str,
                 schema_path: str,
                 utterance_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 input_attention: Attention,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 span_extractor: SpanExtractor = None) -> None:
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)
        # span extractor, allows using spans from the source as input to the decoder
        self._span_extractor = span_extractor
        self._exact_match = Average()
        self._action_similarity = Average()

        self._valid_sql_query = SqlValidity(mydatabase=mydatabase)
        self._token_match = TokenSequenceAccuracy()
        self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path)
        self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path)

        # the padding value used by IndexField
        self._action_padding_index = -1
        num_actions = vocab.get_vocab_size("rule_labels")
        input_action_dim = action_embedding_dim
        if self._add_action_bias:
            input_action_dim += 1
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout)
        self.parse_sql_on_decoding = True
        initializer(self)

    @overrides
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            action_sequence: torch.LongTensor = None,
            spans: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : torch.Tensor, optional (default=None)
            The correct action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``;
            we remove the trailing dimension.
        spans: torch.Tensor, optional (default=None)
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of input spans that could be informative for the decoder. Comes from a ``ListField[SpanField]``
        """
        encode_outputs = self._encode(tokens, spans)
        # encode_outputs['mask'] shape: (batch_size, num_tokens)
        batch_size = encode_outputs['mask'].size(0)
        initial_state = self._get_initial_state(
            encode_outputs['encoder_outputs'], encode_outputs['mask'],
            valid_actions)
        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # action_sequence has shape (batch_size, sequence_length); we unsqueeze it
            # to (batch_size, 1, sequence_length) for the MML trainer.
            try:
                loss_output = self._decoder_trainer.decode(
                    initial_state, self._transition_function,
                    (action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
            except ZeroDivisionError as e:
                logger.info(
                    f"Input utterance in ZeroDivisionError: {[t.text for t in tokens['tokens']]}"
                )
                raise e

            outputs.update(loss_output)

        if not self.training:
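            # Build, for each batch element, a map from action index to the production
            # rule string (action[0]), used to turn index sequences back into SQL.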
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=True)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['target_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Build the target SQL from the gold actions, for token-level
                # exact-match comparison.
                target_sql_query = ''
                if action_sequence is not None:
                    target_action_strings = [
                        action_mapping[i][action_index]
                        for action_index in action_sequence[i].data.tolist()
                        if action_index != self._action_padding_index
                    ]
                    target_sql_query = action_sequence_to_sql(
                        target_action_strings)
                    # target_sql_query = sqlparse.format(target_sql_query, reindent=True)
                target_sql_query_for_acc = target_sql_query.split()

                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._action_similarity(0)
                    outputs['target_sql_query'].append(
                        target_sql_query_for_acc)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [
                    action_mapping[i][action_index]
                    for action_index in best_action_indices
                ]

                predicted_sql_query = action_sequence_to_sql(action_strings)
                predicted_sql_query_for_acc = predicted_sql_query.split()
                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = self._action_history_match(
                        best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    # Compare plain python ints; SequenceMatcher hashes its elements,
                    # which does not behave correctly on 0-dim tensors.
                    similarity = difflib.SequenceMatcher(
                        None, best_action_indices, targets.tolist())
                    self._action_similarity(similarity.ratio())

                    # predicted_sql_query_for_acc = [token if '@' not in token else token.split('@')[1] for token in
                    #                                predicted_sql_query.split()]
                    # target_sql_query_for_acc = [token if '@' not in token else token.split('@')[1] for token in
                    #                             target_sql_query.split()]

                    predicted_sql_query_for_acc = re.sub(
                        r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                        r" \g<1> AS \g<1>\g<2> ", predicted_sql_query).split()
                    target_sql_query_for_acc = re.sub(
                        r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ",
                        r" \g<1> AS \g<1>\g<2> ", target_sql_query).split()

                    self._valid_sql_query([predicted_sql_query_for_acc],
                                          [target_sql_query_for_acc])
                    self._token_match([predicted_sql_query_for_acc],
                                      [target_sql_query_for_acc])
                    self._kb_match([predicted_sql_query_for_acc],
                                   [target_sql_query_for_acc])
                    self._schema_free_match([predicted_sql_query_for_acc],
                                            [target_sql_query_for_acc])

                outputs['best_action_sequence'].append(action_strings)
                # outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['predicted_sql_query'].append(
                    predicted_sql_query_for_acc)
                outputs['target_sql_query'].append(target_sql_query_for_acc)
                outputs['debug_info'].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs
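
    # At evaluation time the returned dict thus holds, per instance,
    # 'best_action_sequence', 'predicted_sql_query', 'target_sql_query' and 'debug_info',
    # plus the batch-level 'action_mapping' (and the trainer's loss whenever gold
    # actions are given).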

    def _encode(self,
                tokens: Dict[str, torch.LongTensor],
                spans: torch.Tensor = None):
        """
        If spans are provided, returns the encoded spans (by self._span_extractor) instead of the
        encoded utterance tokens
        """
        outputs = {}
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        outputs['mask'] = mask
        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance,
                                                      mask))
        outputs['encoder_outputs'] = encoder_outputs
        # if spans (over the input) are given, return their representation instead of the
        # source tokens representation
        if spans is not None and self._span_extractor is not None:
            # Looking at the span start index is enough to know if
            # this is padding or not. Shape: (batch_size, num_spans)
            span_mask = (spans[:, :, 0] >= 0).long()
            span_representations = self._span_extractor(
                encoder_outputs, spans, mask, span_mask)
            outputs["mask"] = span_mask
            outputs["encoder_outputs"] = span_representations
        return outputs
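
    # A shape sketch for ``_encode`` (hypothetical sizes): with a batch of 2 utterances
    # of 7 tokens and an encoder output dim of 100, outputs['encoder_outputs'] is
    # (2, 7, 100) and outputs['mask'] is (2, 7).  If 5 spans per instance are given (and
    # a span extractor is configured), both are replaced by span-level tensors: encoder
    # outputs of shape (2, 5, span_dim) and a (2, 5) span mask.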

    def _get_initial_state(
            self, encoder_outputs: torch.Tensor, mask: torch.Tensor,
            actions: List[List[ProductionRule]]) -> GrammarBasedState:

        batch_size = encoder_outputs.size(0)
        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = encoder_outputs.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(actions[i]) for i in range(batch_size)
        ]
        initial_sql_state = [
            SqlStatelet(actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=actions,
            debug_info=None)
        return initial_state

    @staticmethod
    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence matches a prefix of the (single) target sequence.
        return int(predicted_tensor.equal(targets_trimmed))
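
    # Hypothetical example: with targets = tensor([12, 4, 9, -1]) (padded with the
    # action padding index) and predicted = [12, 4, 9], the prediction matches a prefix
    # of the target, so this returns 1; predicted = [12, 4, 7] would return 0.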

    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True
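
    # For instance (hypothetical tokens): is_nonterminal('"SELECT"') is False, since the
    # token is a quoted terminal, while is_nonterminal('query') is True.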

    @staticmethod
    def get_terminals_mask(action_strings):
        terminals_mask = []
        for rule in action_strings:
            _, rhs = rule.split('->')
            rhs_values = rhs.strip().strip('[]').split(',')
            # A production is a terminal if its single right-hand-side value is quoted
            # (stripping quotes changes it), or if it produces the table placeholder.
            if len(rhs_values) == 1 and rhs_values[0].strip().strip(
                    '"') != rhs_values[0].strip():
                terminals_mask.append(1)
            elif 'TABLE_PLACEHOLDER' in rhs:
                terminals_mask.append(1)
            else:
                terminals_mask.append(0)
        return terminals_mask
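
    # Hypothetical example: for the action strings ['statement -> [query]',
    # 'table_name -> ["TABLE_PLACEHOLDER"]', 'number -> ["1"]'], the returned mask is
    # [0, 1, 1]: the first rule expands to a non-terminal, the other two produce
    # (quoted or placeholder) terminals.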

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track four metrics here:

            1. exact_match, which is the percentage of the time that our best output action sequence
            matches the SQL query exactly.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that can be parsed. (make sure
            you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. valid_sql_query, which is the percentage of time that decoding actually produces a
            valid SQL query.  We might not produce a valid SQL query if the decoder gets
            into a repetitive loop, or we're trying to produce a super long SQL query and run
            out of time steps, or something.

            4. action_similarity, which is how similar the action sequence predicted is to the actual
            action sequence. This is basically a soft measure of exact_match.
        """

        validation_correct = self._exact_match._total_value  # pylint: disable=protected-access
        validation_total = self._exact_match._count  # pylint: disable=protected-access
        all_metrics = {
            '_exact_match_count': validation_correct,
            '_example_count': validation_total,
            'exact_match': self._exact_match.get_metric(reset),
            'sql_validity': self._valid_sql_query.get_metric(reset=reset)['sql_validity'],
            'action_similarity': self._action_similarity.get_metric(reset)
        }
        all_metrics.update(self._token_match.get_metric(reset=reset))
        all_metrics.update(self._kb_match.get_metric(reset=reset))
        all_metrics.update(self._schema_free_match.get_metric(reset=reset))
        return all_metrics

    def _create_grammar_state(
            self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        """
        device = util.get_device_of(self._action_embedder.weight)
        # TODO(Mark): This type is pure \(- . ^)/
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}

        actions_grouped_by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list)
        for i, action in enumerate(possible_actions):
            if action.rule == "":
                continue
            if action.is_global_rule:
                actions_grouped_by_nonterminal[action.nonterminal].append(
                    (action, i))
            else:
                raise ValueError(
                    "The sql parser doesn't support non-global actions yet.")

        for key, production_rule_arrays in actions_grouped_by_nonterminal.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `production_rule_arrays` are all
            # the valid productions of that non-terminal.  This parser only supports global
            # actions, so we embed them all here.
            global_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                global_actions.append(
                    (production_rule_array.rule_id, action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors,
                                                 dim=0).long()
                if device >= 0:
                    global_action_tensor = global_action_tensor.to(device)

                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)

                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))
        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal,
                               reverse_productions=True)

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(
                    predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions,
                                               probabilities):
                    if action != -1:
                        actions.append(
                            (action_mapping[batch_index][action], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['utterance_attention'] = action_debug_info.get(
                    'question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
Ejemplo n.º 39
0
class WikiTablesSemanticParser(Model):
    """
    A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a question,
    and produces a logical form that answers the question when executed over the table.  The
    logical form is generated by a `type-constrained`, `transition-based` parser. This is an
    abstract class that defines most of the functionality related to the transition-based parser. It
    does not contain the implementation for actually training the parser. You may want to train it
    using a learning-to-search algorithm, in which case you will want to use
    ``WikiTablesErmSemanticParser``, or if you have a set of approximate logical forms that give the
    correct denotation, you will want to use ``WikiTablesMmlSemanticParser``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables.
    """
    # pylint: disable=abstract-method
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 add_action_bias: bool = True,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        super().__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._executor = WikiTablesSempreExecutor(tables_directory)
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)


        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(),
                               "entity word average embedding dim", "question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        if self._use_neighbor_similarity_for_linking:
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)
        else:
            self._question_entity_params = None
            self._question_neighbor_params = None

    def _get_initial_rnn_and_grammar_state(self,
                                           question: Dict[str, torch.LongTensor],
                                           table: Dict[str, torch.LongTensor],
                                           world: List[WikiTablesWorld],
                                           actions: List[List[ProductionRule]],
                                           outputs: Dict[str, Any]) -> Tuple[List[RnnStatelet],
                                                                             List[LambdaGrammarStatelet]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnStatelet and LambdaGrammarStatelet for each batch instance to pass to the
        decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            entity_types[i])
                                 for i in range(batch_size)]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state

    @staticmethod
    def _get_neighbor_indices(worlds: List[WikiTablesWorld],
                              num_entities: int,
                              tensor: torch.Tensor) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors. A tensor
        is accepted as a parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index.
        """

        num_neighbors = 0
        for world in worlds:
            for entity in world.table_graph.entities:
                if len(world.table_graph.neighbors[entity]) > num_neighbors:
                    num_neighbors = len(world.table_graph.neighbors[entity])

        batch_neighbors = []
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.table_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.table_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]]
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(neighbor_indexes,
                                                      num_entities,
                                                      lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        return tensor.new_tensor(batch_neighbors, dtype=torch.long)
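
    # A small worked example (hypothetical table): if one world has entities
    # ['fb:cell.oslo', 'fb:row.row.city'], each the other's only neighbor, and the batch
    # maxima are num_neighbors = 2 and num_entities = 3, that instance contributes rows
    # [1, -1] and [0, -1], padded with a [-1, -1] row for the missing third entity.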

    @staticmethod
    def _get_type_vector(worlds: List[WikiTablesWorld],
                         num_entities: int,
                         tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's
        type. In addition, a map from a flattened entity index to type is returned to combine
        entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)
        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types
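
    # Hypothetical example: entities ['fb:cell.oslo', 'fb:part.part', 'fb:row.row.city',
    # '-1'] get type ids [1, 2, 3, 0] (cell, part, row, and number respectively), and
    # the returned dict maps each flattened (batch_index * num_entities + entity_index)
    # key to those same ids.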

    def _get_linking_probabilities(self,
                                   worlds: List[WikiTablesWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(num_question_tokens,
                                                 num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()
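
    # A sketch of the per-type normalization above (hypothetical scores): if a question
    # token scores [2.0, 0.0] against the two cell entities, we softmax over
    # [0.0, 2.0, 0.0] (the prepended null entity's score is pinned to 0) and keep the
    # last two columns; mass assigned to the dropped null column is how a token ends up
    # linked to no cell entity at all.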

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()
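
    # Unlike the single-target version in the SQL parser above, ``targets`` here holds
    # several gold sequences (e.g. one per DPD candidate), so this returns 1 if the
    # prediction matches a prefix of any one of them.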

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. dpd_acc, which is the percentage of the time that our best output action sequence is
            in the set of action sequences provided by DPD.  This is an easy-to-compute lower bound
            on denotation accuracy for the set of examples where we actually have DPD output.  We
            only score dpd_acc on that subset.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that has DPD output (make sure
            you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        return {
                'dpd_acc': self._action_sequence_accuracy.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'lf_percent': self._has_logical_form.get_metric(reset),
                }

    def _create_grammar_state(self,
                              world: WikiTablesWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> LambdaGrammarStatelet:
        """
        This method creates the LambdaGrammarStatelet object that's used for decoding.  Part of
        creating that is creating the `valid_actions` dictionary, which contains embedded
        representations of all of the valid actions.  So, we create that here as well.

        The way we represent the valid expansions is a little complicated: we use a
        dictionary of `action types`, where the key is the action type (like "global", "linked", or
        whatever your model is expecting), and the value is a tuple representing all actions of
        that type.  The tuple is (input tensor, output tensor, action id).  The input tensor has
        the representation that is used when `selecting` actions, for all actions of this type.
        The output tensor has the representation that is used when feeding the action to the next
        step of the decoder (this could just be the same as the input tensor).  The action ids are
        a list of indices into the main action list for each batch instance.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``WikiTablesWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        # TODO(mattg): Move the "valid_actions" construction to another method.
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [action_map[action_string] for action_string in action_strings]
            production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append((production_rule_array[2], action_index))
                else:
                    linked_actions.append((production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions.
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors, dim=0)
            global_input_embeddings = self._action_embedder(global_action_tensor)
            if self._add_action_bias:
                global_action_biases = self._action_biases(global_action_tensor)
                global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases], dim=-1)
            global_output_embeddings = self._output_action_embedder(global_action_tensor)
            translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                       global_output_embeddings,
                                                       list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
                translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                           entity_type_embeddings,
                                                           list(linked_action_ids))

        # Lastly, we need to also create embedded representations of context-specific actions.  In
        # this case, those are only variable productions, like "r -> x".  Note that our language
        # only permits one lambda at a time, so we don't need to worry about how nested lambdas
        # might impact this.
        context_actions = {}
        for action_id, action in enumerate(possible_actions):
            if action[0].endswith(" -> x"):
                input_embedding = self._action_embedder(action[2])
                if self._add_action_bias:
                    input_bias = self._action_biases(action[2])
                    input_embedding = torch.cat([input_embedding, input_bias], dim=-1)
                output_embedding = self._output_action_embedder(action[2])
                context_actions[action[0]] = (input_embedding, output_embedding, action_id)

        return LambdaGrammarStatelet([START_SYMBOL],
                                     {},
                                     translated_valid_actions,
                                     context_actions,
                                     type_declaration.is_nonterminal)
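
    # For one non-terminal this yields entries like (hypothetical shapes)
    # translated_valid_actions['n'] = {'global': (input_embeddings, output_embeddings,
    # [3, 17]), 'linked': (linking_scores, type_embeddings, [42])}: embedded global
    # productions plus linking-score representations for the table-specific ones.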

    def _compute_validation_outputs(self,
                                    actions: List[List[ProductionRule]],
                                    best_final_states: Mapping[int, Sequence[GrammarBasedState]],
                                    world: List[WikiTablesWorld],
                                    example_lisp_string: List[str],
                                    metadata: List[Dict[str, Any]],
                                    outputs: Dict[str, Any]) -> None:
        """
        Does common things for validation time: computing logical form accuracy (which is expensive
        and unnecessary during training), adding visualization info to the output dictionary, etc.

        This doesn't return anything; instead it `modifies` the given ``outputs`` dictionary, and
        calls metrics on ``self``.
        """
        batch_size = len(actions)
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs['action_mapping'] = action_mapping
        outputs['best_action_sequence'] = []
        outputs['debug_info'] = []
        outputs['entities'] = []
        outputs['logical_form'] = []
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = best_final_states[i][0].action_history[0]
                action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                try:
                    logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    self._has_logical_form(1.0)
                except ParsingError:
                    self._has_logical_form(0.0)
                    logical_form = 'Error producing logical form'
                if example_lisp_string:
                    denotation_correct = self._executor.evaluate_logical_form(logical_form,
                                                                              example_lisp_string[i])
                    self._denotation_accuracy(1.0 if denotation_correct else 0.0)
                outputs['best_action_sequence'].append(action_strings)
                outputs['logical_form'].append(logical_form)
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                outputs['entities'].append(world[i].table_graph.entities)
            else:
                outputs['logical_form'].append('')
                self._has_logical_form(0.0)
                self._denotation_accuracy(0.0)
        if metadata is not None:
            outputs["question_tokens"] = [x["question_tokens"] for x in metadata]
            outputs["original_table"] = [x["original_table"] for x in metadata]

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in the ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
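A hypothetical consumption sketch of the decoded output (the `model` and `batch` names are assumptions, not defined in this snippet): after `decode`, each instance carries one dict per decoding step, from which the predicted action and the probability assigned to it can be read back.

# Hypothetical usage sketch: `model` and `batch` are assumed names, not part of the
# snippet above; the field names follow the decode() method it describes.
output_dict = model.decode(model(**batch))
for instance_info in output_dict["predicted_actions"]:
    for step, action_info in enumerate(instance_info):
        predicted = action_info["predicted_action"]
        # considered_actions and action_probabilities are aligned sequences.
        prob_by_action = dict(zip(action_info["considered_actions"],
                                  action_info["action_probabilities"]))
        print(step, predicted, prob_by_action.get(predicted))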
Ejemplo n.º 40
0
class PassageAttnToCount(Model):
    """Predict a count from a passage attention distribution: each token is given a
    sigmoid score, the scores are summed into a predicted mean, and a Gaussian-like
    kernel turns that mean into a distribution over counts 0-9."""
    def __init__(
        self,
        vocab: Vocabulary,
        passage_attention_to_count: Seq2SeqEncoder,
        dropout: float = 0.2,
        initializers: InitializerApplicator = InitializerApplicator()
    ) -> None:

        super(PassageAttnToCount, self).__init__(vocab=vocab)

        self.scaling_vals = [1, 2, 5, 10]

        self.passage_attention_to_count = passage_attention_to_count

        assert len(self.scaling_vals) == self.passage_attention_to_count.get_input_dim()

        self.num_counts = 10
        # self.passage_count_predictor = torch.nn.Linear(self.passage_attention_to_count.get_output_dim(),
        #                                                self.num_counts, bias=False)

        # We want to predict a score for each passage token
        self.passage_count_hidden2logits = torch.nn.Linear(
            self.passage_attention_to_count.get_output_dim(), 1, bias=True)

        self.passagelength_to_bias = torch.nn.Linear(1, 1, bias=True)

        self.count_acc = Average()

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializers(self)
        # self.passage_count_hidden2logits.bias.data.fill_(-1.0)
        # self.passage_count_hidden2logits.bias.requires_grad = False

    def device_id(self):
        # Return the device of this model's parameters (-1 for CPU).
        return allenutil.get_device_of(next(self.parameters()))

    @overrides
    def forward(
            self,
            passage_attention: torch.Tensor,
            passage_lengths: List[int],
            count_answer: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        device_id = allenutil.get_device_of(passage_attention)

        batch_size, max_passage_length = passage_attention.size()

        # Shape: (B, passage_length)
        passage_mask = (passage_attention >= 0).float()

        # Scale the attention by each factor so the count encoder sees several
        # magnitudes of the same distribution; list of (B, passage_length) tensors.
        scaled_attentions = [passage_attention * sf for sf in self.scaling_vals]
        # Shape: (B, passage_length, num_scaling_factors)
        scaled_passage_attentions = torch.stack(scaled_attentions, dim=2)

        # Shape (batch_size, 1)
        passage_len_bias = self.passagelength_to_bias(
            passage_mask.sum(1, keepdim=True))

        # Zero out attention values at padded positions.
        scaled_passage_attentions = scaled_passage_attentions * passage_mask.unsqueeze(2)

        # Shape: (B, passage_length, hidden_dim)
        count_hidden_repr = self.passage_attention_to_count(
            scaled_passage_attentions, passage_mask)

        # Shape: (B, passage_length, 1) -- score for each token
        passage_span_logits = self.passage_count_hidden2logits(
            count_hidden_repr)
        # Shape: (B, passage_length) -- sigmoid on token-score
        token_sigmoids = torch.sigmoid(passage_span_logits.squeeze(2))
        token_sigmoids = token_sigmoids * passage_mask

        # Shape: (B, 1) -- sum of sigmoids. This will act as the predicted mean
        # passage_count_mean = torch.sum(token_sigmoids, dim=1, keepdim=True) + passage_len_bias
        passage_count_mean = torch.sum(token_sigmoids, dim=1, keepdim=True)

        # Shape: (1, num_counts)
        self.countvals = allenutil.get_range_vector(
            self.num_counts, device=device_id).unsqueeze(0).float()

        variance = 0.2

        # Gaussian-like kernel centred on the predicted mean: count values close to
        # the mean receive exponentially more mass. Shape: (batch_size, count_vals)
        l2_by_vsquared = torch.pow(self.countvals - passage_count_mean,
                                   2) / (2 * variance * variance)
        exp_val = torch.exp(-1 * l2_by_vsquared) + 1e-30
        # Shape: (batch_size, count_vals)
        count_distribution = exp_val / torch.sum(exp_val, 1, keepdim=True)

        # Loss computation
        output_dict = {}
        loss = 0.0
        pred_count_idx = torch.argmax(count_distribution, 1)
        if count_answer is not None:
            # L2-loss
            passage_count_mean = passage_count_mean.squeeze(1)
            L2Loss = F.mse_loss(input=passage_count_mean,
                                target=count_answer.float())
            loss = L2Loss
            predictions = passage_count_mean.detach().cpu().numpy()
            predictions = np.round_(predictions)

            gold_count = count_answer.detach().cpu().numpy()
            correct_vec = (predictions == gold_count)
            correct_perc = sum(correct_vec) / batch_size
            # print(f"{correct_perc} {predictions} {gold_count}")
            self.count_acc(correct_perc)

            # loss = F.cross_entropy(input=count_distribution, target=count_answer)
            # List of predicted count idxs, Shape: (B,)
            # correct_vec = (pred_count_idx == count_answer).float()
            # correct_perc = torch.sum(correct_vec) / batch_size
            # self.count_acc(correct_perc.item())

        batch_loss = loss / batch_size
        output_dict["loss"] = batch_loss
        output_dict["passage_attention"] = passage_attention
        output_dict["passage_sigmoid"] = token_sigmoids
        output_dict["count_mean"] = passage_count_mean
        output_dict["count_distritbuion"] = count_distribution
        output_dict["count_answer"] = count_answer
        output_dict["pred_count"] = pred_count_idx

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metric_dict = {}
        count_acc = self.count_acc.get_metric(reset)
        metric_dict.update({'acc': count_acc})

        return metric_dict
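To make the count head above concrete, here is a minimal numeric sketch with illustrative values: the summed token sigmoids act as a predicted mean, and the fixed-variance Gaussian-like kernel over the integers 0..9 converts that mean into a soft count distribution whose argmax is the nearest integer.

# Minimal numeric sketch of the count head (values are illustrative only).
import torch

passage_count_mean = torch.tensor([[3.7]])            # (batch=1, 1), e.g. sum of sigmoids
countvals = torch.arange(10).float().unsqueeze(0)     # (1, 10), counts 0..9

variance = 0.2
l2_by_vsquared = (countvals - passage_count_mean) ** 2 / (2 * variance * variance)
exp_val = torch.exp(-l2_by_vsquared) + 1e-30
count_distribution = exp_val / exp_val.sum(dim=1, keepdim=True)

print(torch.argmax(count_distribution, dim=1))        # tensor([4]): the nearest count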
Ejemplo n.º 41
0
class SimpleProjectionOld(Model):
    """

    """
    def __init__(self,
                 vocab: Vocabulary,
                 input_embedder: TextFieldEmbedder,
                 pooler: Seq2VecEncoder,
                 nli_projection_layer: FeedForward,
                 training_tasks: Any,
                 validation_tasks: Any,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SimpleProjectionOld, self).__init__(vocab, regularizer)
        if isinstance(training_tasks, dict):
            self._training_tasks = list(training_tasks.keys())
        else:
            self._training_tasks = training_tasks

        if isinstance(validation_tasks, dict):
            self._validation_tasks = list(validation_tasks.keys())
        else:
            self._validation_tasks = validation_tasks

        self._input_embedder = input_embedder
        self._pooler = pooler

        self._label_namespace = "labels"
        self._num_labels = vocab.get_vocab_size(
            namespace=self._label_namespace)

        self._nli_projection_layer = nli_projection_layer
        print(vocab.get_token_to_index_vocabulary(namespace=self._label_namespace))
        assert nli_projection_layer.get_output_dim() == self._num_labels

        self._dropout = torch.nn.Dropout(p=dropout)

        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self._nli_projection_layer)

        self._nli_per_lang_acc: Dict[str, CategoricalAccuracy] = dict()

        for taskname in self._validation_tasks:
            # this will hide some metrics from tqdm, but they will still be computed
            self._nli_per_lang_acc[taskname] = CategoricalAccuracy()
        self._nli_avg_acc = Average()

    def forward(
            self,  # type: ignore
            premise_hypothesis: Dict[str, torch.Tensor] = None,
            premise: Dict[str, torch.Tensor] = None,
            hypothesis: Dict[str, torch.LongTensor] = None,
            dataset: List[str] = None,
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise_hypothesis : Dict[str, torch.LongTensor]
            Combined in a single text field for BERT encoding
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``
        dataset : List[str]
            Task indicator
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        if dataset is not None:  # TODO: hardcoded; the default below is used when no multitask reader was used
            taskname = dataset[0]
        else:
            taskname = "nli-en"

        if premise_hypothesis is not None:
            assert premise is None and hypothesis is None
            embedded_combined = self._input_embedder(premise_hypothesis)
            mask = get_text_field_mask(premise_hypothesis).float()
            pooled_combined = self._pooler(embedded_combined, mask=mask)
        elif premise is not None and hypothesis is not None:
            embedded_premise = self._input_embedder(premise)
            embedded_hypothesis = self._input_embedder(hypothesis)

            mask_premise = get_text_field_mask(premise).float()
            mask_hypothesis = get_text_field_mask(hypothesis).float()

            pooled_premise = self._pooler(embedded_premise, mask=mask_premise)
            pooled_hypothesis = self._pooler(embedded_hypothesis,
                                             mask=mask_hypothesis)

            pooled_combined = torch.cat([pooled_premise, pooled_hypothesis],
                                        dim=-1)
        else:
            raise ConfigurationError(
                "One of premise or hypothesis is None. Check your DatasetReader"
            )

        pooled_combined = self._dropout(pooled_combined)

        logits = self._nli_projection_layer(pooled_combined)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss

            self._nli_per_lang_acc[taskname](logits, label)

        if metadata is not None:
            output_dict["premise_tokens"] = [
                x["premise_tokens"] for x in metadata
            ]
            output_dict["hypothesis_tokens"] = [
                x["hypothesis_tokens"] for x in metadata
            ]

        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple argmax over the probabilities, converts index to string label, and
        add ``"label"`` key to the dictionary with the result.
        """
        predictions = output_dict["probs"]
        if predictions.dim() == 2:
            predictions_list = [
                predictions[i] for i in range(predictions.shape[0])
            ]
        else:
            predictions_list = [predictions]
        classes = []
        for prediction in predictions_list:
            label_idx = prediction.argmax(dim=-1).item()
            label_str = (self.vocab.get_index_to_token_vocabulary(
                self._label_namespace).get(label_idx, str(label_idx)))
            classes.append(label_str)
        output_dict["label"] = classes
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = {}

        if self.training:
            tasks = self._training_tasks
        else:
            tasks = self._validation_tasks

        for taskname in tasks:
            metricname = taskname
            # Prefix with '_' to hide non-en/de/ru metrics from tqdm (AllenNLP omits
            # metrics whose names start with an underscore from the progress bar).
            if metricname[-2:] not in ('en', 'de', 'ru'):
                metricname = '_' + metricname
            metrics[metricname] = self._nli_per_lang_acc[taskname].get_metric(
                reset)

        accs = list(metrics.values())  # TODO: should only count 'nli-*' metrics
        num_nonzero = sum(x > 0 for x in accs)
        # Guard against division by zero early in training, when all accuracies are 0.
        avg = sum(accs) / num_nonzero if num_nonzero > 0 else 0.0
        self._nli_avg_acc(avg)
        metrics["nli-avg"] = self._nli_avg_acc.get_metric(reset)

        return metrics
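As a minimal pure-torch sketch of the two input modes in SimpleProjectionOld.forward (the shapes and the `projection` layer below are hypothetical stand-ins, not the AllenNLP API): a combined text field is pooled directly, whereas separate premise and hypothesis fields are pooled independently and concatenated, so the projection layer must expect twice the pooled dimension in the second mode.

# Pure-torch sketch of the "separate fields" mode; names are illustrative stand-ins.
import torch

batch, dim, num_labels = 4, 8, 3
pooled_premise = torch.randn(batch, dim)      # stand-in for self._pooler(embedded_premise)
pooled_hypothesis = torch.randn(batch, dim)   # stand-in for self._pooler(embedded_hypothesis)

pooled_combined = torch.cat([pooled_premise, pooled_hypothesis], dim=-1)  # (4, 16)
projection = torch.nn.Linear(2 * dim, num_labels)  # stand-in for nli_projection_layer
logits = projection(pooled_combined)
probs = torch.nn.functional.softmax(logits, dim=-1)
print(logits.shape, probs.sum(dim=-1))        # torch.Size([4, 3]); each row sums to 1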