def test_tie_break_categorical_accuracy(self):
        """Check the accuracy reported by ``CategoricalAccuracy(tie_break=True)``.

        The expected values below encode fractional credit for tied maxima:
        a row whose highest score is shared by ``k`` classes contributes
        ``1/k`` when the target is one of those classes, and nothing
        otherwise.  Covered cases: a plain batch, the same batch with a mask,
        and a batch of sequences.
        """
        accuracy = CategoricalAccuracy(tie_break=True)
        predictions = torch.Tensor([[0.35, 0.25, 0.35, 0.35, 0.35],
                                    [0.1, 0.6, 0.1, 0.2, 0.2],
                                    [0.1, 0.0, 0.1, 0.2, 0.2]])
        # Test without mask:
        targets = torch.Tensor([2, 1, 4])
        accuracy(predictions, targets)
        # Compare with a tolerance: the expected value is itself the result of
        # floating point arithmetic, so exact equality is needlessly brittle.
        numpy.testing.assert_almost_equal(accuracy.get_metric(reset=True),
                                          (0.25 + 1 + 0.5) / 3.0)

        # Test with mask (the second row is masked out):
        mask = torch.Tensor([1, 0, 1])
        targets = torch.Tensor([2, 1, 4])
        accuracy(predictions, targets, mask)
        numpy.testing.assert_almost_equal(accuracy.get_metric(reset=True),
                                          (0.25 + 0.5) / 2.0)

        # Test tie-break with sequence
        predictions = torch.Tensor([[[0.35, 0.25, 0.35, 0.35, 0.35],
                                     [0.1, 0.6, 0.1, 0.2, 0.2],
                                     [0.1, 0.0, 0.1, 0.2, 0.2]],
                                    [[0.35, 0.25, 0.35, 0.35, 0.35],
                                     [0.1, 0.6, 0.1, 0.2, 0.2],
                                     [0.1, 0.0, 0.1, 0.2, 0.2]]])
        targets = torch.Tensor([[0, 1, 3],  # 0.25 + 1 + 0.5
                                [0, 3, 4]]) # 0.25 + 0 + 0.5 = 2.5
        accuracy(predictions, targets)
        actual_accuracy = accuracy.get_metric(reset=True)
        numpy.testing.assert_almost_equal(actual_accuracy, 2.5/6.0)
 def test_top_k_categorical_accuracy(self):
     """With ``top_k=2`` both rows count as correct, so the metric is 1.0.

     Row 0's target (class 0) has the highest score and row 1's target
     (class 3) has the second highest.
     """
     scores = torch.Tensor([[0.35, 0.25, 0.1, 0.1, 0.2],
                            [0.1, 0.6, 0.1, 0.2, 0.0]])
     gold_labels = torch.Tensor([0, 3])

     metric = CategoricalAccuracy(top_k=2)
     metric(scores, gold_labels)

     assert metric.get_metric() == 1.0
 def test_top_k_categorical_accuracy_respects_mask(self):
     """Masked-out rows must not contribute to the top-2 accuracy.

     The first row is masked; of the two remaining rows only the second one
     has its target among the two highest scores, hence the expected 0.5.
     """
     scores = torch.Tensor([[0.35, 0.25, 0.1, 0.1, 0.2],
                            [0.1, 0.6, 0.1, 0.2, 0.0],
                            [0.1, 0.2, 0.5, 0.2, 0.0]])
     gold_labels = torch.Tensor([0, 3, 0])
     keep = torch.Tensor([0, 1, 1])

     metric = CategoricalAccuracy(top_k=2)
     metric(scores, gold_labels, keep)

     assert metric.get_metric() == 0.50
 def test_top_k_categorical_accuracy_accumulates_and_resets_correctly(self):
     """Counts accumulate over several calls and are zeroed by ``reset=True``.

     Two calls with targets that are in the top two are followed by two
     calls with targets that are not, giving an overall accuracy of 0.5.
     """
     scores = torch.Tensor([[0.35, 0.25, 0.1, 0.1, 0.2],
                            [0.1, 0.6, 0.1, 0.2, 0.0]])
     metric = CategoricalAccuracy(top_k=2)

     # Same order of updates as four explicit calls: hit, hit, miss, miss.
     for labels in ([0, 3], [0, 3], [4, 4], [4, 4]):
         metric(scores, torch.Tensor(labels))

     assert metric.get_metric(reset=True) == 0.50
     # After the reset both internal counters must be back at zero.
     assert metric.correct_count == 0.0
     assert metric.total_count == 0.0
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """
        Store the sub-modules, build the span predictors, validate that the
        configured dimensions agree, and create the metrics.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Passed to the base class constructor.
        text_field_embedder : ``TextFieldEmbedder``
            Its output dim sizes the highway layer and must equal the phrase
            layer's input dim.
        num_highway_layers : ``int``
            Number of highway layers applied on top of the embedder output.
        phrase_layer : ``Seq2SeqEncoder``
            Its output dim is the "encoding dim" used in the checks below.
        similarity_function : ``SimilarityFunction``
            Wrapped in a ``LegacyMatrixAttention``.
        modeling_layer : ``Seq2SeqEncoder``
            Input dim must be ``4 * encoding_dim``.
        span_end_encoder : ``Seq2SeqEncoder``
            Input dim must be ``4 * encoding_dim + 3 * modeling_dim``.
        dropout : ``float``, optional (default=0.2)
            If greater than 0, ``self._dropout`` is a ``torch.nn.Dropout``;
            otherwise it is the identity function.
        mask_lstms : ``bool``, optional (default=True)
            Stored as ``self._mask_lstms``.
        initializer : ``InitializerApplicator``, optional
            Applied to this module once all sub-modules exist.
        regularizer : ``RegularizerApplicator``, optional (default=``None``)
            Passed to the base class constructor.
        """
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity, so call sites can apply ``self._dropout`` unconditionally.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        # Run the configured initialisation only after every sub-module has
        # been registered, otherwise their parameters would be missed.
        initializer(self)

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 attend_feedforward: FeedForward,
                 similarity_function: SimilarityFunction,
                 compare_feedforward: FeedForward,
                 aggregate_feedforward: FeedForward,
                 premise_encoder: Optional[Seq2SeqEncoder] = None,
                 hypothesis_encoder: Optional[Seq2SeqEncoder] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """
        Store the sub-modules, validate the configured dimensions, and create
        the accuracy metric and the loss.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Passed to the base class; its ``"labels"`` namespace determines
            the number of labels.
        text_field_embedder : ``TextFieldEmbedder``
            Output dim must equal ``attend_feedforward``'s input dim.
        attend_feedforward : ``FeedForward``
            Wrapped in ``TimeDistributed``.
        similarity_function : ``SimilarityFunction``
            Wrapped in a ``LegacyMatrixAttention``.
        compare_feedforward : ``FeedForward``
            Wrapped in ``TimeDistributed``.
        aggregate_feedforward : ``FeedForward``
            Output dim must equal the number of labels.
        premise_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
            Stored as-is.
        hypothesis_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
            Falls back to ``premise_encoder`` when not given.
        initializer : ``InitializerApplicator``, optional
            Applied to this module once all sub-modules exist.
        regularizer : ``RegularizerApplicator``, optional (default=``None``)
            Passed to the base class constructor.
        """
        super(DecomposableAttention, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._attend_feedforward = TimeDistributed(attend_feedforward)
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._compare_feedforward = TimeDistributed(compare_feedforward)
        self._aggregate_feedforward = aggregate_feedforward
        self._premise_encoder = premise_encoder
        self._hypothesis_encoder = hypothesis_encoder or premise_encoder

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        check_dimensions_match(text_field_embedder.get_output_dim(), attend_feedforward.get_input_dim(),
                               "text field embedding dim", "attend feedforward input dim")
        check_dimensions_match(aggregate_feedforward.get_output_dim(), self._num_labels,
                               "final output dimension", "number of labels")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        # The initializer was accepted but never used, so a configured
        # initialisation scheme was silently ignored.  Apply it last, after
        # every sub-module has been registered.
        initializer(self)

class LstmTagger(Model):
    """Sequence tagger: embed tokens, encode them, project each step to tag logits."""

    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        """
        Parameters
        ----------
        word_embeddings : ``TextFieldEmbedder``
            Called on the ``sentence`` input of ``forward``.
        encoder : ``Seq2SeqEncoder``
            Called on the embeddings and the mask; its output dim sizes the
            input of the projection layer.
        vocab : ``Vocabulary``
            Passed to the base class; the size of its ``'labels'`` namespace
            is the output size of the projection layer.
        """
        # The base class must be initialised before sub-modules are assigned.
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Compute tag logits and, when ``labels`` is given, the loss.

        Returns a dict with ``"tag_logits"`` and, if ``labels`` is not
        ``None``, ``"loss"``.  The accuracy metric is updated as a side
        effect whenever labels are supplied.
        """
        mask = get_text_field_mask(sentence)
        embeddings = self.word_embeddings(sentence)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return ``{"accuracy": ...}``, forwarding ``reset`` to the metric."""
        return {"accuracy": self.accuracy.get_metric(reset)}
class CoLATask(Task):
    '''Class for Warstadt acceptability task'''

    def __init__(self, path, max_seq_len, name="acceptability"):
        '''Initialise the task as a 2-class, single-sentence task and load its data.

        Args:
            path: directory containing ``train.tsv``, ``dev.tsv`` and ``test.tsv``.
            max_seq_len: forwarded to ``load_tsv``.
            name: task name, forwarded to the base class and used to build
                ``val_metric``.
        '''
        super(CoLATask, self).__init__(name, 2)
        self.pair_input = 0
        self.load_data(path, max_seq_len)
        # NOTE(review): relies on the base class storing ``name`` as
        # ``self.name`` - confirm against ``Task.__init__``.
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer1 = Average()
        self.scorer2 = CategoricalAccuracy()

    def load_data(self, path, max_seq_len):
        '''Load the train, dev and test splits from ``path`` via ``load_tsv``.'''
        # Imported locally so this method does not depend on how (or whether)
        # the enclosing module aliases the logging package.
        import logging

        tr_data = load_tsv(os.path.join(path, "train.tsv"), max_seq_len,
                           s1_idx=3, s2_idx=None, targ_idx=1)
        val_data = load_tsv(os.path.join(path, "dev.tsv"), max_seq_len,
                            s1_idx=3, s2_idx=None, targ_idx=1)
        # The test file is read with different column indices, has no target
        # column, and its first row is skipped.
        te_data = load_tsv(os.path.join(path, 'test.tsv'), max_seq_len,
                           s1_idx=1, s2_idx=None, targ_idx=None, idx_idx=0, skip_rows=1)
        self.train_data_text = tr_data
        self.val_data_text = val_data
        self.test_data_text = te_data
        logging.info("\tFinished loading CoLA.")

    def get_metrics(self, reset=False):
        '''Return both scorers' values; ``reset`` is forwarded to each scorer.'''
        # NB: I think I call it accuracy b/c something weird in training
        return {'accuracy': self.scorer1.get_metric(reset),
                'acc': self.scorer2.get_metric(reset)}
 def __init__(self, path, max_seq_len, name="acceptability"):
     '''Initialise the task as a 2-class, single-sentence task and load its data.

     Args:
         path: forwarded to ``load_data``.
         max_seq_len: forwarded to ``load_data``.
         name: task name, forwarded to the base class and used to build
             ``val_metric``.
     '''
     super(CoLATask, self).__init__(name, 2)
     self.pair_input = 0
     self.load_data(path, max_seq_len)
     # NOTE(review): relies on the base class storing ``name`` as
     # ``self.name`` - confirm against the base class constructor.
     self.val_metric = "%s_accuracy" % self.name
     self.val_metric_decreases = False
     self.scorer1 = Average()
     self.scorer2 = CategoricalAccuracy()
 def __init__(self,
              word_embeddings: TextFieldEmbedder,
              encoder: Seq2SeqEncoder,
              vocab: Vocabulary) -> None:
     """
     Store the embedder and encoder and build the tag projection layer.

     The projection maps the encoder's output dim to the size of the
     vocabulary's ``'labels'`` namespace.
     """
     # The base class must be initialised before sub-modules are assigned.
     super().__init__(vocab)
     self.word_embeddings = word_embeddings
     self.encoder = encoder
     self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                       out_features=vocab.get_vocab_size('labels'))
     self.accuracy = CategoricalAccuracy()
    def test_top_k_categorical_accuracy_works_for_sequences(self):
        """Top-2 accuracy over a batch of sequences, without and with a mask."""
        # Both batch entries hold the same three score rows.
        one_sequence = [[0.35, 0.25, 0.1, 0.1, 0.2],
                        [0.1, 0.6, 0.1, 0.2, 0.0],
                        [0.1, 0.6, 0.1, 0.2, 0.0]]
        scores = torch.Tensor([one_sequence, one_sequence])
        gold_labels = torch.Tensor([[0, 3, 4],
                                    [0, 1, 4]])
        metric = CategoricalAccuracy(top_k=2)

        metric(scores, gold_labels)
        numpy.testing.assert_almost_equal(metric.get_metric(reset=True), 0.6666666)

        # Test the same thing but with a mask:
        keep = torch.Tensor([[0, 1, 1],
                             [1, 0, 1]])
        metric(scores, gold_labels, keep)
        numpy.testing.assert_almost_equal(metric.get_metric(reset=True), 0.50)
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 span_extractor: SpanExtractor,
                 encoder: Seq2SeqEncoder,
                 feedforward_layer: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 evalb_directory_path: str = None) -> None:
        """
        Store the sub-modules, build the tag projection, validate the
        configured dimensions, and create the metrics.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Passed to the base class; its ``"labels"`` namespace gives the
            number of classes.
        text_field_embedder : ``TextFieldEmbedder``
            Its output dim, plus that of ``pos_tag_embedding`` if given, must
            equal the encoder's input dim.
        span_extractor : ``SpanExtractor``
            Its input dim must equal the encoder's output dim.
        encoder : ``Seq2SeqEncoder``
            Stored as ``self.encoder``.
        feedforward_layer : ``FeedForward``, optional (default=``None``)
            If given, wrapped in ``TimeDistributed``; its output dim then
            sizes the tag projection, otherwise the span extractor's does.
        pos_tag_embedding : ``Embedding``, optional (default=``None``)
            Stored as ``self.pos_tag_embedding``.
        initializer : ``InitializerApplicator``, optional
            Applied to this module once all sub-modules exist.
        regularizer : ``RegularizerApplicator``, optional (default=``None``)
            Passed to the base class constructor.
        evalb_directory_path : ``str``, optional (default=``None``)
            If given, an ``EvalbBracketingScorer`` is created from it;
            otherwise ``self._evalb_score`` is ``None``.
        """
        super(SpanConstituencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.span_extractor = span_extractor
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.encoder = encoder
        self.feedforward_layer = TimeDistributed(feedforward_layer) if feedforward_layer else None
        self.pos_tag_embedding = pos_tag_embedding or None
        if feedforward_layer is not None:
            output_dim = feedforward_layer.get_output_dim()
        else:
            output_dim = span_extractor.get_output_dim()

        self.tag_projection_layer = TimeDistributed(Linear(output_dim, self.num_classes))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        check_dimensions_match(representation_dim,
                               encoder.get_input_dim(),
                               "representation dim (tokens + optional POS tags)",
                               "encoder input dim")
        # NOTE(review): the first label says "encoder input dim" although the
        # value checked is the encoder *output* dim; the string is kept
        # unchanged because it is runtime text.
        check_dimensions_match(encoder.get_output_dim(),
                               span_extractor.get_input_dim(),
                               "encoder input dim",
                               "span extractor input dim")
        if feedforward_layer is not None:
            check_dimensions_match(span_extractor.get_output_dim(),
                                   feedforward_layer.get_input_dim(),
                                   "span extractor output dim",
                                   "feedforward input dim")

        self.tag_accuracy = CategoricalAccuracy()

        if evalb_directory_path is not None:
            self._evalb_score = EvalbBracketingScorer(evalb_directory_path)
        else:
            self._evalb_score = None

        # Apply the configured initialisation after all sub-modules exist.
        initializer(self)
 def __init__(self, name, n_classes):
     '''Set the task name, class count and default bookkeeping attributes.

     Args:
         name: task name; also used to build ``val_metric``.
         n_classes: number of classes, stored as ``self.n_classes``.
     '''
     self.name = name
     self.n_classes = n_classes
     self.train_data_text, self.val_data_text, self.test_data_text = \
         None, None, None
     self.train_data = None
     self.val_data = None
     self.test_data = None
     self.pred_layer = None
     self.pair_input = 1
     self.categorical = 1 # most tasks are
     self.val_metric = "%s_accuracy" % self.name
     self.val_metric_decreases = False
     self.scorer1 = CategoricalAccuracy()
     self.scorer2 = None
 def __init__(self,
              #### The embedding layer is specified as an AllenNLP <code>TextFieldEmbedder</code> which represents a general way of turning tokens into tensors. (Here we know that we want to represent each unique word with a learned tensor, but using the general class allows us to easily experiment with different types of embeddings, for example <a href = "">ELMo</a>.)
              word_embeddings: TextFieldEmbedder,
              #### Similarly, the encoder is specified as a general <code>Seq2SeqEncoder</code> even though we know we want to use an LSTM. Again, this makes it easy to experiment with other sequence encoders, for example a Transformer.
              encoder: Seq2SeqEncoder,
              #### Every AllenNLP model also expects a <code>Vocabulary</code>, which contains the namespaced mappings of tokens to indices and labels to indices.
              vocab: Vocabulary) -> None:
     #### Notice that we have to pass the vocab to the base class constructor.
     super().__init__(vocab)
     self.word_embeddings = word_embeddings
     self.encoder = encoder
     #### The feed forward layer is not passed in as a parameter, but is constructed by us. Notice that it looks at the encoder to find the correct input dimension and looks at the vocabulary (and, in particular, at the label -> index mapping) to find the correct output dimension.
     self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                       out_features=vocab.get_vocab_size('labels'))
     #### The last thing to notice is that we also instantiate a <code>CategoricalAccuracy</code> metric, which we'll use to track accuracy during each training and validation epoch.
     self.accuracy = CategoricalAccuracy()
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 output_logit: FeedForward,
                 dropout: float = 0.5,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """
        Store the sub-modules, configure dropout, validate the configured
        dimensions, and create the accuracy metric and the loss.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Passed to the base class; its ``"labels"`` namespace determines
            the number of labels.
        text_field_embedder : ``TextFieldEmbedder``
            Output dim must equal the encoder's input dim.
        encoder : ``Seq2SeqEncoder``
            Four times its output dim must equal ``projection_feedforward``'s
            input dim.
        similarity_function : ``SimilarityFunction``
            Wrapped in a ``LegacyMatrixAttention``.
        projection_feedforward : ``FeedForward``
            Output dim must equal ``inference_encoder``'s input dim.
        inference_encoder : ``Seq2SeqEncoder``
            Stored as ``self._inference_encoder``.
        output_feedforward : ``FeedForward``
            Stored as ``self._output_feedforward``.
        output_logit : ``FeedForward``
            Stored as ``self._output_logit``.
        dropout : ``float``, optional (default=0.5)
            If truthy, a ``Dropout`` and an ``InputVariationalDropout`` with
            this probability are created; otherwise both are ``None``.
        initializer : ``InitializerApplicator``, optional
            Applied to this module once all sub-modules exist.
        regularizer : ``RegularizerApplicator``, optional (default=``None``)
            Passed to the base class constructor.
        """
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder

        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._projection_feedforward = projection_feedforward

        self._inference_encoder = inference_encoder

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
            self.rnn_input_dropout = InputVariationalDropout(dropout)
        else:
            # Without this branch the modules created above were always
            # overwritten with ``None`` and dropout was never applied.
            self.dropout = None
            self.rnn_input_dropout = None

        self._output_feedforward = output_feedforward
        self._output_logit = output_logit

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(encoder.get_output_dim() * 4, projection_feedforward.get_input_dim(),
                               "encoder output dim", "projection feedforward input")
        check_dimensions_match(projection_feedforward.get_output_dim(), inference_encoder.get_input_dim(),
                               "proj feedforward output dim", "inference lstm input dim")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        # Apply the configured initialisation after all sub-modules exist.
        initializer(self)

class LstmTagger(Model):
    #### One thing that might seem unusual is that we're going pass in the embedder and the sequence encoder as constructor parameters. This allows us to experiment with different embedders and encoders without having to change the model code.
    def __init__(self,
                 #### The embedding layer is specified as an AllenNLP <code>TextFieldEmbedder</code> which represents a general way of turning tokens into tensors. (Here we know that we want to represent each unique word with a learned tensor, but using the general class allows us to easily experiment with different types of embeddings, for example <a href = "">ELMo</a>.)
                 word_embeddings: TextFieldEmbedder,
                 #### Similarly, the encoder is specified as a general <code>Seq2SeqEncoder</code> even though we know we want to use an LSTM. Again, this makes it easy to experiment with other sequence encoders, for example a Transformer.
                 encoder: Seq2SeqEncoder,
                 #### Every AllenNLP model also expects a <code>Vocabulary</code>, which contains the namespaced mappings of tokens to indices and labels to indices.
                 vocab: Vocabulary) -> None:
        #### Notice that we have to pass the vocab to the base class constructor.
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        #### The feed forward layer is not passed in as a parameter, but is constructed by us. Notice that it looks at the encoder to find the correct input dimension and looks at the vocabulary (and, in particular, at the label -> index mapping) to find the correct output dimension.
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        #### The last thing to notice is that we also instantiate a <code>CategoricalAccuracy</code> metric, which we'll use to track accuracy during each training and validation epoch.
        self.accuracy = CategoricalAccuracy()

    #### Next we need to implement <code>forward</code>, which is where the actual computation happens. Each <code>Instance</code> in your dataset will get (batched with other instances and) fed into <code>forward</code>. The <code>forward</code> method expects dicts of tensors as input, and it expects their names to be the names of the fields in your <code>Instance</code>. In this case we have a sentence field and (possibly) a labels field, so we'll construct our <code>forward</code> accordingly:
    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        #### AllenNLP is designed to operate on batched inputs, but different input sequences have different lengths. Behind the scenes AllenNLP is padding the shorter inputs so that the batch has uniform shape, which means our computations need to use a mask to exclude the padding. Here we just use the utility function <code>get_text_field_mask</code>, which returns a tensor of 0s and 1s corresponding to the padded and unpadded locations.
        mask = get_text_field_mask(sentence)
        #### We start by passing the <code>sentence</code> tensor (each sentence a sequence of token ids) to the <code>word_embeddings</code> module, which converts each sentence into a sequence of embedded tensors.
        embeddings = self.word_embeddings(sentence)
        #### We next pass the embedded tensors (and the mask) to the LSTM, which produces a sequence of encoded outputs.
        encoder_out = self.encoder(embeddings, mask)
        #### Finally, we pass each encoded output tensor to the feedforward layer to produce logits corresponding to the various tags.
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}

        #### As before, the labels were optional, as we might want to run this model to make predictions on unlabeled data. If we do have labels, then we use them to update our accuracy metric and compute the "loss" that goes in our output.
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    #### We included an accuracy metric that gets updated each forward pass. That means we need to override a <code>get_metrics</code> method that pulls the data out of it. Behind the scenes, the <code>CategoricalAccuracy</code> metric is storing the number of predictions and the number of correct predictions, updating those counts during each call to forward. Each call to get_metric returns the calculated accuracy and (optionally) resets the counts, which is what allows us to track accuracy anew for each epoch.
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}
class Task():
    '''Abstract class for a task

    Methods and attributes:
        - load_data: load dataset from a path and create splits
        - yield dataset for training
        - dataset size
        - validate and test

    Outside the task:
        - process: pad and indexify data given a mapping
        - optimizer
    '''
    # NOTE(review): Python 3 ignores the ``__metaclass__`` attribute, so this
    # does not make the class abstract there; left unchanged so instantiation
    # behaviour stays the same.
    __metaclass__ = ABCMeta

    def __init__(self, name, n_classes):
        '''Set the task name, class count and default bookkeeping attributes.

        Args:
            name: task name; also used to build ``val_metric``.
            n_classes: number of classes, stored as ``self.n_classes``.
        '''
        self.name = name
        self.n_classes = n_classes
        self.train_data_text, self.val_data_text, self.test_data_text = \
            None, None, None
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.pred_layer = None
        self.pair_input = 1
        self.categorical = 1 # most tasks are
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer1 = CategoricalAccuracy()
        self.scorer2 = None

    def load_data(self, path, max_seq_len):
        '''
        Load data from path and create splits.

        Must be overridden by subclasses; this base version always raises
        ``NotImplementedError``.
        '''
        raise NotImplementedError

    def get_metrics(self, reset=False):
        '''Get metrics specific to the task'''
        acc = self.scorer1.get_metric(reset)
        return {'accuracy': acc}
class MultiHopAttentionQAFreezeDetRes101(Model):
    """Multi-hop attention QA model whose detector backbone is frozen.

    NOTE(review): several statements in this class had lost their trailing
    arguments or closing brackets.  They have been completed from the names
    that are still in scope; every completion that could not be derived from
    the remaining code is flagged with its own ``NOTE(review)`` below and
    should be confirmed against the project's history.
    """

    def __init__(
            self,
            vocab: Vocabulary,
            span_encoder: Seq2SeqEncoder,
            reasoning_encoder: Seq2SeqEncoder,
            input_dropout: float = 0.3,
            hidden_dim_maxpool: int = 1024,
            class_embs: bool = True,
            reasoning_use_obj: bool = True,
            reasoning_use_answer: bool = True,
            reasoning_use_question: bool = True,
            pool_reasoning: bool = True,
            pool_answer: bool = True,
            pool_question: bool = False,
            initializer: InitializerApplicator = InitializerApplicator(),
    ):
        """
        Build the detector, the encoders, both attention modules and the
        final scoring MLP.

        :param vocab: passed to the base class constructor
        :param span_encoder: encoder applied (time-distributed) to each span
        :param reasoning_encoder: encoder applied (time-distributed) to the
            reasoning input
        :param input_dropout: dropout probability used for the recurrent
            input dropout (disabled when not > 0) and inside the final MLP
        :param hidden_dim_maxpool: hidden size of the final MLP
        :param class_embs: forwarded to the detector (see note below)
        :param reasoning_use_obj: include the attended objects in the
            reasoning input
        :param reasoning_use_answer: include the answer representation in the
            reasoning input
        :param reasoning_use_question: include the attended question in the
            reasoning input
        :param pool_reasoning: include the reasoning output in the pooled
            representation
        :param pool_answer: include the answer representation in the pooled
            representation
        :param pool_question: include the attended question in the pooled
            representation
        :param initializer: applied to this module once everything is built
        """
        super(MultiHopAttentionQAFreezeDetRes101, self).__init__(vocab)

        # NOTE(review): only ``pretrained=True`` survived in the original
        # call; the remaining arguments are reconstructed - confirm.
        self.detector = SimpleDetector(pretrained=True,
                                       average_pool=True,
                                       semantic=class_embs,
                                       final_dim=512)

        # freeze everything related to conv net
        for submodule in self.detector.backbone.modules():
            # if isinstance(submodule, BatchNorm2d):
            # submodule.track_running_stats = False
            for p in submodule.parameters():
                p.requires_grad = False

        for submodule in self.detector.after_roi_align.modules():
            # if isinstance(submodule, BatchNorm2d):
            # submodule.track_running_stats = False
            for p in submodule.parameters():
                p.requires_grad = False

        self.rnn_input_dropout = TimeDistributed(
            InputVariationalDropout(
                input_dropout)) if input_dropout > 0 else None

        self.span_encoder = TimeDistributed(span_encoder)
        self.reasoning_encoder = TimeDistributed(reasoning_encoder)

        # NOTE(review): the arguments of both attention modules were missing
        # and are reconstructed from how the modules are called in
        # ``forward`` - confirm.
        self.span_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=span_encoder.get_output_dim(),
        )

        self.obj_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=self.detector.final_dim,
        )

        self.reasoning_use_obj = reasoning_use_obj
        self.reasoning_use_answer = reasoning_use_answer
        self.reasoning_use_question = reasoning_use_question
        self.pool_reasoning = pool_reasoning
        self.pool_answer = pool_answer
        self.pool_question = pool_question
        # Input width of the final MLP: sum of the widths of every
        # representation that is selected for pooling.
        dim = sum([
            d for d, to_pool in [(
                reasoning_encoder.get_output_dim(), self.pool_reasoning
            ), (span_encoder.get_output_dim(), self.pool_answer
                ), (span_encoder.get_output_dim(), self.pool_question)]
            if to_pool
        ])

        # NOTE(review): the ReLU between the two linear layers is
        # reconstructed (without a non-linearity the stack collapses to a
        # single linear map) - confirm.
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(dim, hidden_dim_maxpool),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(hidden_dim_maxpool, 1),
        )
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return: tensor of shape [*span_tags.shape, obj_dim]
        """
        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for _ in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        # Pair every tag with the index of its batch row, gather, and restore
        # the original leading shape.
        return object_reps[row_id.view(-1),
                           span_tags_fixed.view(-1)].view(
                               *span_tags_fixed.shape, -1)

    def embed_span(self, span, span_tags, span_mask, object_reps):
        """
        :param span: Thing that will get embed and turned into [batch_size, ..leading_dims.., L, word_dim]
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :param span_mask: [batch_size, ..leading_dims.., span_mask
        :return: tuple of (encoded span, retrieved object features)
        """
        retrieved_feats = self._collect_obj_reps(span_tags, object_reps)

        span_rep = torch.cat((span['bert'], retrieved_feats), -1)
        # add recurrent dropout here
        if self.rnn_input_dropout:
            span_rep = self.rnn_input_dropout(span_rep)

        return self.span_encoder(span_rep, span_mask), retrieved_feats

    def forward(self,
                images: torch.Tensor,
                objects: torch.LongTensor,
                segms: torch.Tensor,
                boxes: torch.Tensor,
                box_mask: torch.LongTensor,
                question: Dict[str, torch.Tensor],
                question_tags: torch.LongTensor,
                question_mask: torch.LongTensor,
                answers: Dict[str, torch.Tensor],
                answer_tags: torch.LongTensor,
                answer_mask: torch.LongTensor,
                metadata: List[Dict[str, Any]] = None,
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        :param images: [batch_size, 3, im_height, im_width]
        :param objects: [batch_size, max_num_objects] Padded objects
        :param boxes:  [batch_size, max_num_objects, 4] Padded boxes
        :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK
        :param question: AllenNLP representation of the question. [batch_size, num_answers, seq_length]
        :param question_tags: A detection label for each item in the Q [batch_size, num_answers, seq_length]
        :param question_mask: Mask for the Q [batch_size, num_answers, seq_length]
        :param answers: AllenNLP representation of the answer. [batch_size, num_answers, seq_length]
        :param answer_tags: A detection label for each item in the A [batch_size, num_answers, seq_length]
        :param answer_mask: Mask for the As [batch_size, num_answers, seq_length]
        :param metadata: Ignore, this is about which dataset item we're on
        :param label: Optional, which item is valid
        :return: dict with ``label_logits``, ``label_probs`` and
            ``cnn_regularization_loss``, plus ``loss`` when ``label`` is given
        """
        # Trim off boxes that are too long. this is an issue b/c dataparallel, it'll pad more zeros that are
        # not needed
        max_len = int(box_mask.sum(1).max().item())
        objects = objects[:, :max_len]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        segms = segms[:, :max_len]

        # NOTE(review): a tag equal to ``max_len`` passes this check although
        # it looks out of range for the trimmed objects - confirm whether
        # ``>=`` was intended.
        for tag_type, the_tags in (('question', question_tags), ('answer',
                                                                 answer_tags)):
            if int(the_tags.max()) > max_len:
                raise ValueError(
                    "Oh no! {}_tags has maximum of {} but objects is of dim {}. Values are\n{}"
                    .format(tag_type, int(the_tags.max()), objects.shape,
                            the_tags))

        # NOTE(review): keyword arguments after ``images`` are reconstructed
        # from the inputs that are trimmed above - confirm.
        obj_reps = self.detector(images=images,
                                 boxes=boxes,
                                 box_mask=box_mask,
                                 classes=objects,
                                 segms=segms)

        # Now get the question representations
        q_rep, q_obj_reps = self.embed_span(question, question_tags,
                                            question_mask,
                                            obj_reps['obj_reps'])
        a_rep, a_obj_reps = self.embed_span(answers, answer_tags, answer_mask,
                                            obj_reps['obj_reps'])

        # Perform Q by A attention
        # [batch_size, 4, question_length, answer_length]
        qa_similarity = self.span_attention(
            q_rep.view(q_rep.shape[0] * q_rep.shape[1], q_rep.shape[2],
                       q_rep.shape[3]),
            a_rep.view(a_rep.shape[0] * a_rep.shape[1], a_rep.shape[2],
                       a_rep.shape[3]),
        ).view(a_rep.shape[0], a_rep.shape[1], q_rep.shape[2], a_rep.shape[2])
        # Normalise over the question axis so each answer position gets a
        # distribution over question positions.
        qa_attention_weights = masked_softmax(qa_similarity,
                                              question_mask[..., None],
                                              dim=2)
        attended_q = torch.einsum('bnqa,bnqd->bnad',
                                  (qa_attention_weights, q_rep))

        # Have a second attention over the objects, do A by Objs
        # [batch_size, 4, answer_length, num_objs]
        atoo_similarity = self.obj_attention(
            a_rep.view(a_rep.shape[0], a_rep.shape[1] * a_rep.shape[2], -1),
            obj_reps['obj_reps']).view(a_rep.shape[0], a_rep.shape[1],
                                       a_rep.shape[2],
                                       obj_reps['obj_reps'].shape[1])
        atoo_attention_weights = masked_softmax(atoo_similarity,
                                                box_mask[:, None, None])
        attended_o = torch.einsum(
            'bnao,bod->bnad', (atoo_attention_weights, obj_reps['obj_reps']))

        reasoning_inp = torch.cat([
            x for x, to_pool in [(a_rep, self.reasoning_use_answer
                                  ), (attended_o, self.reasoning_use_obj),
                                 (attended_q, self.reasoning_use_question)]
            if to_pool
        ], -1)

        if self.rnn_input_dropout is not None:
            reasoning_inp = self.rnn_input_dropout(reasoning_inp)
        reasoning_output = self.reasoning_encoder(reasoning_inp, answer_mask)

        things_to_pool = torch.cat([
            x
            for x, to_pool in [(reasoning_output,
                                self.pool_reasoning), (a_rep,
                                                       self.pool_answer),
                               (attended_q, self.pool_question)] if to_pool
        ], -1)

        # Max-pool over the answer positions; masked positions are pushed to
        # a large negative value first so they can never win the max.
        # NOTE(review): the fill value and the pooling call are reconstructed
        # - confirm.
        pooled_rep = replace_masked_values(things_to_pool, answer_mask[...,
                                                                       None],
                                           -1e7).max(2)[0]
        logits = self.final_mlp(pooled_rep).squeeze(2)

        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {
            "label_logits": logits,
            "label_probs": class_probabilities,
            'cnn_regularization_loss': obj_reps['cnn_regularization_loss'],
            # Uncomment to visualize attention, if you want
            # 'qa_attention_weights': qa_attention_weights,
            # 'atoo_attention_weights': atoo_attention_weights,
        }
        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            self._accuracy(logits, label)
            output_dict["loss"] = loss[None]

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return ``{'accuracy': ...}``, forwarding ``reset`` to the metric."""
        return {'accuracy': self._accuracy.get_metric(reset)}
class DialogQA(Model):
    """
    This class implements a modified version of the BiDAF model
    (with self attention and a residual layer, from the Clark and Gardner ACL 17 paper) as used in
    the Question Answering in Context (EMNLP 2018) paper.

    In this set-up, a single instance is a dialog, i.e. a list of question-answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    residual_encoder : ``Seq2SeqEncoder``
        The encoder applied to the merged passage representation before the self attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    initializer : ``InitializerApplicator``
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    marker_embedding_dim : ``int``, optional (default=10)
        Dimension of the previous-answer marker embeddings (the question turn-number embedding
        uses this value multiplied by ``num_context_answers``).
    max_span_length: ``int``, optional (default=30)
        Maximum token length of the output span.
    max_turn_length: ``int``, optional (default=12)
        Maximum length of an interaction.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30,
                 max_turn_length: int = 12) -> None:
        # Stores the configuration and builds the attention layers, span / yes-no / follow-up
        # predictors, metric trackers and dropout module of the model.
        # NOTE(review): no ``super().__init__(...)`` call is visible in this body — confirm the
        # base-class initialiser runs before the attributes below are assigned.
        # NOTE(review): ``initializer`` is accepted but never used in the visible body — confirm
        # whether it is meant to be applied to the model.
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        # All attention / merge layers below are sized from the phrase layer's output dimension.
        self._encoding_dim = phrase_layer.get_output_dim()

        # Passage-question attention and the projection of its 4 * encoding_dim merged features.
        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        # Marker embeddings only exist when previous answers are part of the input.
        if num_context_answers > 0:
            # Turn-number embedding: one row per turn index, up to max_turn_length.
            self._question_num_marker = torch.nn.Embedding(max_turn_length,
                                                           marker_embedding_dim * num_context_answers)
            # Previous-answer marker embedding with (num_context_answers * 4) + 1 rows.
            self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        # NOTE(review): the statement below is cut off after its first argument (the parentheses
        # are never closed) — the output size of the linear layer is missing; TODO restore.
        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        # One logit per token for span start / end; three logits per token for yes-no and follow-up.
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        # NOTE(review): the four lines below are the trailing arguments of a call whose opening
        # line is missing; judging by the message strings it presumably validated the phrase
        # layer's input dimension against embedding dim + marker dims — TODO restore.
                               text_field_embedder.get_output_dim() +
                               marker_embedding_dim * num_context_answers,
                               "phrase layer input dim",
                               "embedding dim + marker dim * num context answers")


        # Metric trackers.
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # Runs the model on a batch of dialogs: embeds and encodes questions and passage, applies
        # passage-question attention and passage self attention, predicts span start / end plus
        # yes-no and follow-up classes, and (when gold spans are given) accumulates loss and metrics.
        # NOTE(review): the prose from here down to the ``batch_size, ...`` assignment reads like
        # this method's docstring, but no string delimiters or section headers are visible around
        # it, and a few of its sentences stop mid-way — TODO restore.
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage token will have assigned 'O', except the passage tokens belongs to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        An output dictionary consisting of the followings.
        Each of the followings is a nested list because first iterates over dialog, then questions in dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: np)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        # The question character tensor is unpacked as (batch, qa turns, question length, chars).
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        # Every (dialog, turn) pair is flattened into one row for the rest of the computation.
        total_qa_count = batch_size * max_qa_count
        # NOTE(review): the right-hand side below is incomplete (the call in front of ``, 0)`` is
        # missing) — TODO restore. ``qa_mask`` is later passed as ``mask=`` to the accuracy metrics.
        qa_mask =, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        # NOTE(review): the reshape call below is cut off after its second argument.
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        # The passage mask is repeated once per question turn and flattened to match the questions.
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            # NOTE(review): the callee in front of ``[...]`` is missing here and in the similar
            # statements below; presumably a concatenation along the last dimension — TODO confirm.
            embedded_question =[embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage =[repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage =[repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage =[repeated_embedded_passage, p3_answer_marker_emb],

            # NOTE(review): the call below is cut off, and the three statements after it look like
            # they belong to a missing alternative branch (no context answers), since they would
            # otherwise overwrite ``repeated_encoded_passage`` — TODO confirm and restore.
            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        # NOTE(review): the call below is cut off — its mask and fill-value arguments are missing.
        masked_similarity = util.replace_masked_values(passage_question_similarity,

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        # NOTE(review): the expand call below is cut off after its first argument.
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        # NOTE(review): the callee and closing arguments are missing below, and only three tensors
        # are listed although the shape comment above says encoding_dim * 4 — one element has
        # presumably been lost; TODO restore.
        final_merged_passage =[repeated_encoded_passage,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        # NOTE(review): the call below is cut off — the encoder's mask argument is not visible.
        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        # Pairwise passage mask with the diagonal zeroed, so a token never attends to itself.
        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        # NOTE(review): callee and closing arguments are missing below (see the earlier note).
        self_attention_vecs =[self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        # Residual connection around the self-attention block.
        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        # NOTE(review): the statement below is garbled — the callee around ``[...]`` and the
        # remaining encoder arguments are missing; TODO restore.
        end_rep = self._span_end_encoder([final_merged_passage, start_rep], dim=-1),
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        # Yes-no and follow-up logits are predicted from the span-end representation.
        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        # NOTE(review): the call below is cut off — the helper's fifth argument (the maximum span
        # length) is not visible.
        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            # NOTE(review): the call below is cut off; the matching span-end call further down
            # passes ``ignore_index=-1`` — presumably the same was intended here.
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            # For every flattened question, the three flat indices of the class logits located at
            # the gold span end are collected; negative values are clamped to 0 by max(..., 0).
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            # NOTE(review): empty right-hand side — presumably the index list was converted to a
            # tensor here, since it is passed to ``index_select`` below; TODO restore.
            gold_span_end_loc =

            # Same flat indices, but located at the predicted span end.
            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            # NOTE(review): empty right-hand side — see the note on ``gold_span_end_loc`` above.
            predicted_end =

            # The loss uses the logits at the gold span end ...
            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            # ... while the accuracy metrics use the logits at the predicted span end.
            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        # NOTE(review): ``metadata`` defaults to None but is indexed unconditionally below.
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            # NOTE(review): nothing visible appends to the four per-dialog lists below or to the
            # ``output_dict`` lists created above — the appending statements appear to be missing.
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                # Token indices are mapped to character offsets in the original passage text.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]

                best_span_string = passage_str[start_offset:end_offset]
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        # NOTE(review): this loop looks incomplete — ``answer_index`` is never used
                        # to drop a reference and nothing is appended to ``t_f1``; the single-answer
                        # branch and the final call's arguments are also missing. TODO restore.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            refs = [answer_texts[z] for z in idxes]
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                self._official_f1(100 * f1_score)
        return output_dict

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Replace the predicted ``yesno`` and ``followup`` indices with their vocabulary tokens.

        Parameters
        ----------
        output_dict : ``Dict[str, torch.Tensor]``
            Must contain ``"yesno"`` and ``"followup"`` entries, each a nested list of label
            indices (one inner list per dialog).  The dictionary is modified in place.

        Returns
        -------
        The same dictionary, where ``"yesno"`` holds tokens looked up in the ``yesno_labels``
        namespace and ``"followup"`` holds tokens looked up in the ``followup_labels`` namespace.
        """
        decoded = {}
        # "yesno" is handled (and removed) before "followup", exactly one after the other.
        for key, namespace in (("yesno", "yesno_labels"), ("followup", "followup_labels")):
            tags_per_dialog = []
            for dialog_indices in output_dict.pop(key):
                dialog_tags = []
                for label_index in dialog_indices:
                    dialog_tags.append(self.vocab.get_token_from_index(label_index, namespace=namespace))
                tags_per_dialog.append(dialog_tags)
            decoded[key] = tags_per_dialog
        # Both entries are written back only after both have been decoded.
        output_dict.update(decoded)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the current values of the span, yes-no, follow-up and F1 metrics.

        Parameters
        ----------
        reset : ``bool``, optional (default=False)
            Passed through unchanged to every metric's ``get_metric``.

        Returns
        -------
        A dictionary with the keys ``start_acc``, ``end_acc``, ``span_acc``, ``yesno``,
        ``followup`` and ``f1``.
        """
        reported = (('start_acc', '_span_start_accuracy'),
                    ('end_acc', '_span_end_accuracy'),
                    ('span_acc', '_span_accuracy'),
                    ('yesno', '_span_yesno_accuracy'),
                    ('followup', '_span_followup_accuracy'),
                    ('f1', '_official_f1'))
        metrics = {}
        # Metrics are queried one after another, in the order they are reported.
        for metric_name, attribute in reported:
            metrics[metric_name] = getattr(self, attribute).get_metric(reset)
        return metrics

    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        # Returns the index of highest-scoring span that is not longer than 30 tokens, as well as
        # yesno prediction bit and followup prediction bit from the predicted span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4), dtype=torch.long)

        span_start_logits =
        span_end_logits =
        span_yesno_logits =
        span_followup_logits =
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
 def __init__(self,
              vocab: Vocabulary,
              regularizer: Optional[RegularizerApplicator] = None) -> None:
     """
     Initialise the base class and register a categorical accuracy metric.

     Both arguments are forwarded unchanged to the parent initialiser.
     """
     super().__init__(vocab, regularizer)
     accuracy_metric = CategoricalAccuracy()
     self.metrics = {'accuracy': accuracy_metric}
class CommonsenseQATask(MultipleChoiceTask):
    """ Task class for CommonsenseQA Task.  """
    def __init__(self, path, max_seq_len, name, easy=False, **kw):
        """Store the data location and options, and set up scoring and answer-label mappings.

        ``name`` and any extra keyword arguments are forwarded to the parent initialiser.
        """
        super().__init__(name, **kw)
        self.path, self.max_seq_len = path, max_seq_len

        self.easy = easy
        # No split text is loaded at construction time.
        self.train_data_text = self.val_data_text = self.test_data_text = None

        accuracy_scorer = CategoricalAccuracy()
        self.scorer1 = accuracy_scorer
        self.scorers = [accuracy_scorer]
        self.val_metric = "%s_accuracy" % name
        self.val_metric_decreases = False

        # Five answer options, labelled A-E; both directions of the mapping are kept.
        choice_labels = ["A", "B", "C", "D", "E"]
        self.n_choices = 5
        self.label2choice_idx = {label: position for position, label in enumerate(choice_labels)}
        self.choice_idx2label = choice_labels

    def load_data(self):
        """ Process the dataset located at path.  """
        # Reads the train / dev / test JSON-lines files found under ``self.path`` and stores the
        # parsed splits on the instance.
        # NOTE(review): several statements in this method are cut off (unbalanced brackets, missing
        # lines); the notes below mark each spot — TODO restore before relying on this code.
        def _load_split(data_file):
            questions, choices, targs, id_str = [], [], [], []
            # One JSON document per line.
            # NOTE(review): the file handle opened here is never closed explicitly.
            data = [json.loads(l) for l in open(data_file, encoding="utf-8")]
            for example in data:
                # NOTE(review): the call below is cut off after its second argument.
                question = tokenize_and_truncate(
                    self._tokenizer_name, "Q:" + example["question"]["stem"],
                # NOTE(review): the key expression and the start of the value expression of this
                # dict comprehension are missing, as is its closing brace.
                choices_dict = {
                                          "A:" + a_choice["text"],
                    for a_choice in example["question"]["choices"]
                # Choices are ordered by the fixed label order in ``self.choice_idx2label``.
                # NOTE(review): the closing bracket of this list is missing.
                multiple_choices = [
                    choices_dict[label] for label in self.choice_idx2label
                # Examples without an "answerKey" field get target index 0.
                targ = self.label2choice_idx[
                    example["answerKey"]] if "answerKey" in example else 0
                example_id = example["id"]
            # NOTE(review): nothing visible appends ``question``, ``multiple_choices``, ``targ`` or
            # ``example_id`` to the four lists returned here — the appends appear to be missing.
            return [questions, choices, targs, id_str]

        # The "easy" flag switches the train and dev files; the test file is the same either way.
        train_file = "train_rand_split_EASY.jsonl" if self.easy else "train_rand_split.jsonl"
        val_file = "dev_rand_split_EASY.jsonl" if self.easy else "dev_rand_split.jsonl"
        test_file = "test_rand_split_no_answers.jsonl"
        self.train_data_text = _load_split(os.path.join(self.path, train_file))
        self.val_data_text = _load_split(os.path.join(self.path, val_file))
        self.test_data_text = _load_split(os.path.join(self.path, test_file))
        # ``self.sentences`` gathers the train and dev questions plus every train and dev choice;
        # the test split is not included.
        # NOTE(review): the closing ``)`` below is fused with the remains of another statement
        # (a message string whose callee is missing, presumably a logging call) — TODO restore.
        self.sentences = (self.train_data_text[0] + self.val_data_text[0] + [
            choice for choices in self.train_data_text[1] for choice in choices
        ] + [
            choice for choices in self.val_data_text[1] for choice in choices
        ])"\tFinished loading CommonsenseQA data.")

    def process_split(
            self, split, indexers,
            model_preprocessing_interface) -> Iterable[Type[Instance]]:
        """ Process split text into a list of AllenNLP Instances. """
        # NOTE(review): several statements in ``_make_instance`` are cut off (unbalanced brackets,
        # missing lines); the notes below mark each spot — TODO restore.
        def _make_instance(question, choices, label, id_str):
            # Builds the field dictionary for one example.
            d = {}
            d["question_str"] = MetadataField(" ".join(question))
            # NOTE(review): the flag name of this condition and the arguments of the call on the
            # next line are missing; presumably the flag is "uses_pair_embedding", as used in the
            # loop below — TODO confirm.
            if not model_preprocessing_interface.model_flags[
                d["question"] = sentence_to_text_field(
            for choice_idx, choice in enumerate(choices):
                # With pair embeddings, question and choice are passed to the boundary-token
                # function together.
                # NOTE(review): the ``else`` operand and the closing parenthesis are missing.
                inp = (model_preprocessing_interface.boundary_token_fn(
                    question, choice) if model_preprocessing_interface.
                       model_flags["uses_pair_embedding"] else
                d["choice%d" % choice_idx] = sentence_to_text_field(
                    inp, indexers)
                d["choice%d_str" % choice_idx] = MetadataField(
                    " ".join(choice))
            # NOTE(review): the call below is cut off after its first argument.
            d["label"] = LabelField(label,
            d["id_str"] = MetadataField(id_str)
            return Instance(d)

        # ``split`` is unpacked into parallel sequences, one per ``_make_instance`` argument.
        split = list(split)
        # ``map`` is lazy: instances are only built as the caller iterates over the result.
        instances = map(_make_instance, *split)
        return instances

    def get_metrics(self, reset=False):
        """Return the task accuracy as ``{"accuracy": value}``; ``reset`` is passed to the scorer."""
        return dict(accuracy=self.scorer1.get_metric(reset))
class QuarelSemanticParser(Model):
    A ``QuarelSemanticParser`` is a variant of ``WikiTablesSemanticParser`` with various
    tweaks and changes.

    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=0)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  Passing 8 here matches the default in the ``KnowledgeGraphField``,
        which is to use all eight defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question.
    use_entities : ``bool``, optional (default=False)
        Whether dynamic entities are part of the action space
    num_entity_bits : ``int``, optional (default=0)
        Whether any bits are added to encoder input/output to represent tagged entities
    entity_bits_output : ``bool``, optional (default=True)
        Whether entity bits are added to the encoder output or input
    denotation_only : ``bool``, optional (default=False)
        Whether to only predict target denotation, skipping the whole logical form decoder
    entity_similarity_mode : ``str``, optional (default="dot_product")
        How to compute vector similarity between question and entity tokens, can take values
        "dot_product" or "weighted_dot_product" (learned weights on each dimension)
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    # Constructor: stores the embedder / encoder / beam search, sets up the metric accumulators,
    # the entity-type embeddings, the optional similarity and linking-feature layers, the tied
    # action embeddings and the decoder transition function.
    #
    # NOTE(review): several statements in this method appear to have lost tokens or whole lines
    # (see the inline notes); they need to be restored from version control before this runs.
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 num_linking_features: int = 0,
                 num_entity_bits: int = 0,
                 entity_bits_output: bool = True,
                 use_entities: bool = False,
                 denotation_only: bool = False,
                 # Deprecated parameter to load older models
                 entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
                 entity_similarity_mode: str = "dot_product",
                 rule_namespace: str = 'rule_labels') -> None:
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        # NOTE(review): both assignments below sit inside the ``if``; an ``else:`` in front of the
        # identity fallback looks to be missing.  As written the lambda replaces the Dropout
        # module, and ``_dropout`` is never set when ``dropout`` is 0.
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        # Running averages reported by ``get_metrics``.
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 1 # Hardcoded until we feed lf syntax into the model
        self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

        # Only the "weighted_dot_product" mode gets a learned similarity layer.
        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = \
                TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
            # NOTE(review): the statement below has lost its left-hand side.
   += 1  # pylint: disable=protected-access
        # NOTE(review): as written, the default mode "dot_product" raises the "Invalid" error;
        # a no-op branch and an ``else:`` in front of the ``raise`` look to be missing.
        elif self._entity_similarity_mode == "dot_product":
            raise ValueError("Invalid entity_similarity_mode: {}".format(self._entity_similarity_mode))

        # NOTE(review): same pattern as the dropout block above -- the ``None`` assignment is
        # inside the ``if`` and overwrites the Linear layer; an ``else:`` looks to be missing.
        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        # When entity bits are attached to the encoder output, the effective output width grows
        # by ``num_entity_bits``.
        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            # NOTE(review): the ``Linear(...)`` call below is truncated, and no early ``return``
            # follows the comment that says the rest of the constructor is not needed.
            self._denotation_classifier = torch.nn.Linear(self._encoder_output_dim,
            # Rest of init not needed for denotation only where no decoding to actions needed

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        # We are tying the action embeddings used for input and output
        # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        # NOTE(review): no explicit initialisation of these two parameters is visible here.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(self._encoder_output_dim))

        # NOTE(review): the constructor call below is truncated after its first keyword argument.
        self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder_output_dim,

    # Forward pass: embeds the question and the table entities, optionally computes
    # question/entity linking scores, encodes the question, and then either classifies the
    # denotation directly (``self._denotation_only``) or sets up decoder state for the trainer /
    # beam search and assembles the output dictionary.
    #
    # NOTE(review): this method has lost a number of tokens and lines (docstring delimiters,
    # ``else:`` / ``try:`` lines, call heads such as the concatenation calls, and the tails of
    # several multi-line calls).  The inline notes mark the places that need restoring.
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[QuarelWorld],
                actions: List[List[ProductionRule]],
                entity_bits: torch.Tensor = None,
                denotation_target: torch.Tensor = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``,
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value to
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                # NOTE(review): the ``view(...)`` calls in the next two statements are truncated.
                question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                           num_entities * num_entity_tokens,
                                                       torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(batch_size,

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                # Expand both sides so every (entity token, question token) pair can be multiplied
                # element-wise before the learned similarity layer scores it.
                eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                # NOTE(review): both ``view(...)`` calls below are truncated.
                product = product.view(batch_size,
                question_entity_similarity = self._entity_similarity_layer(product)
                question_entity_similarity = question_entity_similarity.view(batch_size,

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                    question_mask, entity_type_dict)
            # NOTE(review): the branch structure here is damaged -- the concatenation call head is
            # missing from the second ``encoder_input`` assignment, and the ``else:`` lines that
            # would separate the alternatives (including the "fake linking_scores" fallback for
            # the no-entities case) are not present.
            encoder_input = embedded_question
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input =[embedded_question, entity_bits], 2)
                encoder_input = embedded_question

            # Fake linking_scores added for downstream code to not object
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            # NOTE(review): the concatenation call head is missing from this statement.
            encoder_outputs =[encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        # NOTE(review): the call below is truncated after its first argument.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(final_encoder_output)
            loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
            # NOTE(review): an ``else:`` in front of the next line looks to be missing; as written
            # the mask computed above is immediately discarded.
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some concatenations.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        # NOTE(review): the body of this loop (which presumably fills ``initial_rnn_state``) is
        # missing.
        for i in range(batch_size):

        initial_grammar_state = [self._create_grammar_state(world[i], actions[i],
                                                            linking_scores[i], entity_types[i])
                                 for i in range(batch_size)]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        # NOTE(review): the ``GrammarBasedState(...)`` call is truncated, and the condition that
        # selects between the training path (first ``return outputs``) and the inference path
        # below it is not present.
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],

            outputs = self._decoder_trainer.decode(initial_state,
                                                   (target_action_sequences, target_mask))
            return outputs

            # Map (batch index, action index) -> production rule string for later lookup.
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            # NOTE(review): the right-hand side of this assignment (the search call) is missing.
            best_final_states =,
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            if self._linking_params is not None:
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
                if metadata is not None:
                    outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    # NOTE(review): the ``try:`` that pairs with the ``except`` below is missing.
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        logical_form = 'Error producing logical form'
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index)
                    score = math.exp(best_final_states[i][0].score[0].data.cpu().item())
                    # NOTE(review): of the values computed above, only ``debug_info`` is stored in
                    # ``outputs`` in the code visible here; the metric updates and the other
                    # per-instance appends are not present.
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
            return outputs

    @staticmethod
    def _get_type_vector(worlds: List[QuarelWorld],
                         num_entities: int,
                         tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's
        type. In addition, a map from a flattened entity index to type is returned to combine
        entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
            One world per batch instance; the entities are read from ``world.table_graph.entities``.
        num_entities : ``int``
            Length every per-instance type list is padded to (padding value 0).
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    # Anything without one of the prefixes above falls into type 0.
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)
        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types

    def _get_linking_probabilities(self,
                                   worlds: List[QuarelWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        worlds : ``List[QuarelWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:

                if len(entity_indices) == 1:
                    # No entities of this type; move along...

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(num_question_tokens,
                                                 num_entities - num_entities_in_instance)

            # (num_question_tokens, num_entities)
            probabilities =, dim=1)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()

    def _denotation_match(self, predicted_answer_index: int, target_answer_index: int) -> float:
        if predicted_answer_index < 0:
            # Logical form doesn't properly resolve, we do random guess with appropriate credit
            return 1.0/self._num_denotation_cats
        elif predicted_answer_index == target_answer_index:
            return 1.0
        return 0.0

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. parse_acc, which is the percentage of the time that our best output action sequence
            corresponds to a correct logical form

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation, including spurious correct answers using the wrong logical form

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.

        When the model runs in denotation-only mode, only ``denotation_acc`` is reported, taken
        from the categorical accuracy metric.  ``reset`` is passed on to each metric's
        ``get_metric``.
        """
        if self._denotation_only:
            metrics = {'denotation_acc': self._denotation_accuracy_cat.get_metric(reset)}
        else:
            metrics = {
                    'parse_acc': self._action_sequence_accuracy.get_metric(reset),
                    'denotation_acc': self._denotation_accuracy.get_metric(reset),
                    'lf_percent': self._has_logical_form.get_metric(reset),
            }
        return metrics

    # Builds the per-instance ``GrammarStatelet`` used for decoding: for every non-terminal it
    # splits the valid productions into "global" actions (embedded through the action embedders)
    # and "linked" actions (scored through the entity linking scores and entity-type embeddings).
    #
    # NOTE(review): this method has lost its docstring delimiters, an ``else:`` line, the
    # concatenation call heads, and the tails of three multi-line expressions; see inline notes.
    def _create_grammar_state(self,
                              world: QuarelWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        world : ``QuarelWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        # Production rule string -> index into ``possible_actions``.
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        # Entity name -> index into the world's entity list.
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [action_map[action_string] for action_string in action_strings]
            production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                # NOTE(review): an ``else:`` in front of the ``linked_actions`` append looks to be
                # missing; as written, every rule that passes the test lands in both lists and
                # the others in neither.
                if production_rule_array[1]:
                    global_actions.append((production_rule_array[2], action_index))
                    linked_actions.append((production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions.
            global_action_tensors, global_action_ids = zip(*global_actions)
            # NOTE(review): the call heads are missing from the two concatenation statements
            # below, and the tuple assigned to ``['global']`` is truncated.
            global_action_tensor =, dim=0)
            global_input_embeddings = self._action_embedder(global_action_tensor)
            if self._add_action_bias:
                global_action_biases = self._action_biases(global_action_tensor)
                global_input_embeddings =[global_input_embeddings, global_action_biases], dim=-1)
            global_output_embeddings = self._output_action_embedder(global_action_tensor)
            translated_valid_actions[key]['global'] = (global_input_embeddings,

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                # The entity name is the right-hand side of the production rule string.
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
                # NOTE(review): the tuple assigned below is truncated after its first element.
                translated_valid_actions[key]['linked'] = (entity_linking_scores,

        # NOTE(review): the constructor call below is truncated after its first argument.
        return GrammarStatelet([START_SYMBOL],

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in the transition function held in
        ``self._decoder_step``.

        For every instance in the batch and every step of its best action sequence, this method
        collects the predicted action, the actions that were considered at that step (padding
        entries, marked with ``-1``, are dropped and the rest are mapped back to their rule
        strings through ``action_mapping``), their probabilities, and the question attention from
        the step's debug info.  The result is added to ``output_dict`` under
        ``predicted_actions`` as a list (batch) of lists (steps) of dictionaries, and the same
        ``output_dict`` is returned.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                if actions:
                    considered_actions, probabilities = zip(*actions)
                else:
                    # Every considered action was padding; ``zip(*[])`` cannot be unpacked.
                    considered_actions, probabilities = (), ()
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's Bidirectional Attention Flow model
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        similarity_function: SimilarityFunction,
        modeling_layer: Seq2SeqEncoder,
        span_end_encoder: Seq2SeqEncoder,
        dropout: float = 0.2,
        mask_lstms: bool = True,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        # Imported locally so the dimension checks below do not depend on a module-level import.
        from allennlp.common.checks import check_dimensions_match

        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(
            modeling_layer.get_input_dim(),
            4 * encoding_dim,
            "modeling layer input dim",
            "4 * encoding dim",
        )
        check_dimensions_match(
            text_field_embedder.get_output_dim(),
            phrase_layer.get_input_dim(),
            "text field embedder output dim",
            "phrase layer input dim",
        )
        check_dimensions_match(
            span_end_encoder.get_input_dim(),
            4 * encoding_dim + 3 * modeling_dim,
            "span end encoder input dim",
            "4 * encoding dim + 3 * modeling dim",
        )

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity function, so ``forward`` can call ``self._dropout`` unconditionally.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                encoded_passage,
                passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat(
            [
                final_merged_passage,
                modeled_passage,
                tiled_start_representation,
                modeled_passage * tiled_start_representation,
            ],
            dim=-1,
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            token_offsets = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                token_offsets.append(metadata[i]["token_offsets"])
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["token_offsets"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Character offsets: start of the first token through end of the last token.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = metadata[i].get("answer_texts", [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
            output_dict["token_offsets"] = token_offsets
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the span start/end/whole-span accuracies and the SQuAD exact-match and F1 scores
        accumulated so far.  ``reset`` is forwarded to every underlying metric.
        """
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }

    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.
        span_log_mask = (torch.triu(
            torch.ones((passage_length, passage_length),
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
class QaNet(Model):
    """
    This class implements Adams Wei Yu's QANet model
    for machine reading comprehension published at ICLR 2018.

    The overall architecture of QANet is very similar to BiDAF. The main difference is that QANet
    replaces the RNN encoder with CNN + self-attention. There are also some minor differences in the
    modeling layer and output layer.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the passage-question attention.
    matrix_attention_layer : ``MatrixAttention``
        The matrix attention function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    dropout_prob : ``float``, optional (default=0.1)
        If greater than 0, we will apply dropout with this probability between layers.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        matrix_attention_layer: MatrixAttention,
        modeling_layer: Seq2SeqEncoder,
        dropout_prob: float = 0.1,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        text_embed_dim = text_field_embedder.get_output_dim()
        encoding_in_dim = phrase_layer.get_input_dim()
        encoding_out_dim = phrase_layer.get_output_dim()
        modeling_in_dim = modeling_layer.get_input_dim()
        modeling_out_dim = modeling_layer.get_output_dim()

        self._text_field_embedder = text_field_embedder

        # Project the embeddings to the width the highway and phrase layers expect.
        self._embedding_proj_layer = torch.nn.Linear(text_embed_dim,
                                                     encoding_in_dim)
        self._highway_layer = Highway(encoding_in_dim, num_highway_layers)

        self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim,
                                                    encoding_in_dim)
        self._phrase_layer = phrase_layer

        self._matrix_attention = matrix_attention_layer

        # The merged attention vectors are 4 * encoding_out_dim wide; project them down to the
        # modeling layer's input width.
        self._modeling_proj_layer = torch.nn.Linear(encoding_out_dim * 4,
                                                    modeling_in_dim)
        self._modeling_layer = modeling_layer

        self._span_start_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)
        self._span_end_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._metrics = SquadEmAndF1()
        self._dropout = torch.nn.Dropout(
            p=dropout_prob) if dropout_prob > 0 else lambda x: x

        initializer(self)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(
            self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(
            embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(
            embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention,
                                             question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage,
                                                    attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            ))

        modeled_passage_list = [
            self._modeling_proj_layer(merged_passage_attention_vectors)
        ]

        # Apply the modeling layer three times, each pass feeding on the previous output; the
        # three outputs are combined pairwise below to predict span start and end.
        for _ in range(3):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2))
        span_start_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["token_offsets"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Character offsets: start of the first token through end of the last token.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = metadata[i].get("answer_texts", [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the span start/end/whole-span accuracies and the SQuAD exact-match and F1 scores
        accumulated so far.  ``reset`` is forwarded to every underlying metric.
        """
        exact_match, f1_score = self._metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }

    def __init__(self,
                 params: Params,
                 vocab: Vocabulary) -> None:
        """
        Build the model's sub-modules and metrics from a ``Params`` object.

        Hyper-parameters are popped from ``params`` with the defaults shown below; the token
        embedder is built from the ``token_embedder`` sub-params.

        NOTE(review): this copy of the constructor is damaged.  Several statements are cut off
        mid-expression (each is flagged inline), no ``super().__init__`` call is visible, and
        the method sits directly after another class's ``get_metrics``.  Restore the missing
        text from version control before relying on this code.
        """

        enc_hidden_dim = params.pop_int('enc_hidden_dim', 300)
        gen_hidden_dim = params.pop_int('gen_hidden_dim', 300)
        disc_hidden_dim = params.pop_int('disc_hidden_dim', 1200)
        disc_num_layers = params.pop_int('disc_num_layers', 1)
        # NOTE(review): call is never closed - the remaining argument(s) are missing.
        code_dist_type = params.pop_choice(
            'code_dist_type', ['gaussian', 'vmf'],
        code_dim = params.pop_int('code_dim', 300)
        tie_embedding = params.pop_bool('tie_embedding', False)

        emb_dropout = params.pop_float('emb_dropout', 0.0)
        disc_dropout = params.pop_float('disc_dropout', 0.0)
        l2_weight = params.pop_float('l2_weight', 0.0)

        self.emb_dropout = nn.Dropout(emb_dropout)
        self.disc_dropout = nn.Dropout(disc_dropout)
        self._l2_weight = l2_weight

        self._token_embedder = Embedding.from_params(
            vocab=vocab, params=params.pop('token_embedder'))
        # NOTE(review): the wrapped recurrent module's constructor and its first argument(s)
        # are missing from the next two statements.
        self._encoder = PytorchSeq2VecWrapper(
                    hidden_size=enc_hidden_dim, batch_first=True))
        self._generator = PytorchSeq2SeqWrapper(
                                + code_dim),
                    hidden_size=gen_hidden_dim, batch_first=True))
        # NOTE(review): the arguments of ``nn.Linear`` are missing.
        self._generator_projector = nn.Linear(
        if tie_embedding:
            # Share the projection weight with the token embedding weight.
            self._generator_projector.weight = self._token_embedder.weight

        if code_dist_type == 'vmf':
            vmf_kappa = params.pop_int('vmf_kappa', 150)
            self._code_generator = VmfCodeGenerator(
                code_dim=code_dim, kappa=vmf_kappa)
        elif code_dist_type == 'gaussian':
            # NOTE(review): the arguments of ``GaussianCodeGenerator`` are missing, and the
            # ``raise`` below presumably belonged to a separate ``else:`` branch - confirm.
            self._code_generator = GaussianCodeGenerator(
            raise ValueError('Unknown code_dist_type')

        # NOTE(review): this call is never closed.
        self._discriminator = FeedForward(
            input_dim=2 * self._code_generator.get_output_dim(),
            hidden_dims=[disc_hidden_dim]*disc_num_layers + [self._NUM_LABELS],
            num_layers=disc_num_layers + 1,
            activations=[Activation.by_name('relu')()] * disc_num_layers
                        + [Activation.by_name('linear')()],

        self._kl_weight = 1.0
        self._discriminator_weight = params.pop_float(
            'discriminator_weight', 0.1)
        self._gumbel_temperature = 1.0

        # Metrics
        # NOTE(review): the dict literal is never closed.
        self._metrics = {
            'generator_loss': ScalarMetric(),
            'kl_divergence': ScalarMetric(),
            'discriminator_accuracy': CategoricalAccuracy(),
            'discriminator_loss': ScalarMetric(),
            'loss': ScalarMetric()
class MLMBaseline(Model):
    """
    Baseline multiple-choice model.  The embedded ``phrase`` is flattened, zero-padded to a
    fixed width and handed, together with the first token id of every choice, to a
    ``BERTLikeLMModelHead`` which returns the loss and the label logits.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``phrase`` ``TextField`` we get as input to the model.
    dropout : ``float``, optional (default=0.5)
        Accepted for configuration compatibility; it is not used anywhere in this class.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    # Input width the ``BERTLikeLMModelHead`` is constructed with; the flattened phrase
    # embedding is zero-padded up to exactly this size.
    _HEAD_INPUT_DIM = 1000

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 dropout: float = 0.5,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder

        self._BERTLikeLMModelHead = BERTLikeLMModelHead(self._HEAD_INPUT_DIM, vocab.get_vocab_size())

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

    def forward(self,  # type: ignore
                phrase: Dict[str, torch.LongTensor],
                choices: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        phrase : Dict[str, torch.LongTensor]
            From a ``TextField``; ``phrase['tokens']`` must be 2-dimensional
            ``(batch_size, num_tokens)``.
        choices : Dict[str, torch.LongTensor]
            Candidate answers; only ``choices['tokens'][:, :, 0]`` (the first id of every
            choice) is used.
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : List[Dict[str, Any]], optional (default = None)
            Not used by this method.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            The logits returned by the ``BERTLikeLMModelHead``.
        label_probs : torch.FloatTensor
            Softmax of ``label_logits`` over the last dimension.
        loss : torch.FloatTensor
            The loss value returned by the ``BERTLikeLMModelHead``.
        """
        batch_size, _ = phrase['tokens'].shape
        # Indexing with ``0`` already drops the last axis.  The former extra ``.squeeze()``
        # also collapsed a batch (or choice) axis of size 1 and broke single-example batches.
        # NOTE(review): assumes ``choices['tokens']`` is (batch, num_choices, num_tokens) - confirm.
        choices_ids = choices['tokens'][:, :, 0]
        embedded_phrase = self._text_field_embedder(phrase)

        # Putting the batch_size first and concatenating the token embeddings of each phrase.
        embedded_phrase = embedded_phrase.view(batch_size, -1)

        # Zero-pad on the right to reach the exact classifier input size.  The amount is derived
        # from the real flattened width instead of assuming 50-dimensional embeddings.
        pad_amount = self._HEAD_INPUT_DIM - embedded_phrase.size(-1)
        embedded_phrase = F.pad(input=embedded_phrase, pad=(0, pad_amount), mode='constant', value=0)

        # applying the 2 layers MLP
        loss, label_logits = self._BERTLikeLMModelHead(embedded_phrase, choices_ids, label)

        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        # The accuracy metric needs gold labels; skip it at prediction time.
        if label is not None:
            self._accuracy(label_logits, label)
        output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running categorical accuracy under both the ``accuracy`` and ``EM`` keys."""
        metric = self._accuracy.get_metric(reset)
        return {
            'accuracy': metric,
            'EM': metric
        }
class SNLISiameseModel(Model):
    """
    This is a simple Siamese model to be trained on SNLI dataset, the main purpose is
    to transfer the trained weights to the RAHP model.

    Network structure:
        embed - encode - concatenate 2 vecs - make prediction

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Embeds both the premise and the hypothesis.
    context_encoder : ``Seq2SeqEncoder``
        Applied to the embedded premise and hypothesis (shared weights).
    text_encoder_entailment : ``Seq2VecEncoder``
        Turns each contextualised sequence into one fixed-size vector (shared weights).
    matching_layer : ``FeedForward``
        Maps the concatenation of the two sentence vectors to the class logits.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 context_encoder: Seq2SeqEncoder,
                 text_encoder_entailment: Seq2VecEncoder,
                 matching_layer: FeedForward,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None
                 ) -> None:
        # Forward the regularizer so that it is no longer silently ignored.
        super(SNLISiameseModel, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.context_encoder = context_encoder
        self.text_encoder_entailment = text_encoder_entailment
        self.matching_layer = matching_layer

        # running value, reset the values every epoch
        self.accuracy = CategoricalAccuracy()
        self.criterion = torch.nn.CrossEntropyLoss()

        # Apply the configured initializer once every sub-module has been registered.
        initializer(self)

    def forward(self,
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Encode premise and hypothesis with the shared encoders and classify the pair.

        Returns a dictionary with ``probs`` (softmax over the logits) and, when ``label`` is
        given, ``loss`` (cross entropy); the accuracy metric is updated in that case too.
        """
        # shape = (batch_size, seq_len)
        p_mask = get_text_field_mask(premise)
        h_mask = get_text_field_mask(hypothesis)
        # shape = (batch_size, seq_len, embed_dim)
        embedded_p = self.text_field_embedder(premise)
        embedded_h = self.text_field_embedder(hypothesis)
        # context encoder
        context_p = self.context_encoder(embedded_p, p_mask)
        context_h = self.context_encoder(embedded_h, h_mask)
        # inference encoder: encode to fixed-size vectors
        o_p = self.text_encoder_entailment(context_p, p_mask)
        o_h = self.text_encoder_entailment(context_h, h_mask)
        # concatenate both sentence vectors and feed them to the FC layer
        logits = self.matching_layer(torch.cat([o_p, o_h], dim=-1))
        probs = F.softmax(logits, dim=-1)
        output_dict = {'probs': probs}

        if label is not None:
            loss = self.criterion(logits, label.long().view(-1))
            self.accuracy(logits, label)
            output_dict['loss'] = loss
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running classification accuracy under the ``accuracy`` key."""
        return {"accuracy": self.accuracy.get_metric(reset)}
    # Constructor of ``GeneralGenerationForClassfiication`` (class name taken from the
    # ``super(...)`` call below).
    # NOTE(review): this block looks truncated - the body reads many names that are absent from
    # the visible parameter list (``use_cls``, ``use_kld_loss``, ``output_dim``, ``dropout``,
    # ``k``, ``temperature``, ``load_weights``, ...), and several statements are left unclosed.
    # Confirm against the complete source before editing the code.
    def __init__(self,
                 vocab: Vocabulary,
                 model_name: str,
                 initializer: InitializerApplicator = InitializerApplicator(),
        super(GeneralGenerationForClassfiication, self).__init__(vocab)
        # Generator: an XLNet LM-head model and its tokenizer, both loaded from ``model_name``.
        # NOTE(review): the ``from_pretrained(model_name,`` call below is missing its remaining
        # arguments and closing parenthesis.
        self.gen_model = XLNetLMHeadModel.from_pretrained(model_name,
        self.tokenizer = XLNetTokenizer.from_pretrained(model_name)
        # Keep direct handles on the generator's word-embedding module and its weight.
        self.gen_word_embedding = self.gen_model.transformer.word_embedding
        self.gen_embeddings_weight = self.gen_word_embedding.weight

        # Optional separate XLNet encoder used for classification.
        if use_cls:
            self.cls_model = XLNetModel.from_pretrained(model_name)
            self.cls_word_embedding = self.cls_model.word_embedding
            self.cls_embeddings_weight = self.cls_word_embedding.weight
        # A second LM-head model is loaded only when the KLD loss is enabled.
        if use_kld_loss:
            self.freezed_lm = XLNetLMHeadModel.from_pretrained(model_name)

        # Width of the classification input is chosen from the model name: 768 when it
        # contains 'base', otherwise 1024.
        n_embd = 768 if 'base' in model_name else 1024
        self.cls = nn.Linear(n_embd, output_dim, bias=True)
        self.use_cls = use_cls
        self.use_similarity = use_similarity
        self.train_generator = train_generator
        self.dropout = nn.Dropout(dropout)
        self.k = k

        # Loss configuration (stored verbatim from the constructor arguments).
        self.use_kld_loss = use_kld_loss
        self.lm_loss_coeff = lm_loss_coeff
        self.anneal_kld_loss = anneal_kld_loss
        self.sim_coeff = sim_coeff
        self.use_repetition_loss = use_repetition_loss
        self.rep_coeff = rep_coeff
        self.anneal_repetition_loss = anneal_repetition_loss
        self.sequence_ngram_n = sequence_ngram_n

        if freeze_embeddings:
            self.gen_embeddings_weight.requires_grad = False

        # When the generator is not trained, its embeddings are frozen as well and
        # ``generate_until_dot`` is forced to True (it is stored on ``self`` further down).
        if not train_generator:
            self.gen_embeddings_weight.requires_grad = False
            generate_until_dot = True

        # Sampling / softmax configuration.
        self.temperature = temperature
        self.train_with_regular_softmax = train_with_regular_softmax
        self.use_straight_through_gumbel_softmax = use_straight_through_gumbel_softmax
        self.anneal_temperature = anneal_temperature
        # Note the attribute name differs from the argument name here.
        self.topk_gs = output_several_results_on_every_step
        self.results_each_step = results_each_step

        self.generate_until_dot = generate_until_dot
        self.pass_only_generated = pass_only_generated

        self.train_with_just_sim_loss_for_epochs_num = train_with_just_sim_loss_for_epochs_num
        self.add_cls_after_epoch_num = add_cls_after_epoch_num
        self.use_similarity_btw_question_and_answers = use_similarity_btw_question_and_answers
        self.decouple_gen_and_cls_embs = decouple_gen_and_cls_embs
        self.pass_probabilities_to_classifier = pass_probabilities_to_classifier
        self.zero_generated_out = zero_generated_out
        # These two attributes are also stored under names that differ from their arguments.
        self.supervised_generator = train_lm_generator
        self.gen_lm_loss_coeff = gen_lm_loss_coeff
        self.train_cls_without_sup_gen = train_cls_without_lm_loss

        # NOTE(review): the body of this ``if`` is not present in this view.
        if load_weights:

        # Running metrics: two accuracies and four averaged loss terms.
        # NOTE(review): the closing brace of this dict is not present in this view.
        self.metrics = {
            "accuracy": CategoricalAccuracy(),
            "sim_accuracy": CategoricalAccuracy(),
            "kld_loss": Average(),
            "repetition_loss": Average(),
            "classification_loss": Average(),
            "similarity_loss": Average(),
    # Constructor of ``MultiHopAttentionQAFreezeDetRes101`` (class name taken from the
    # ``super(...)`` call below).
    # NOTE(review): this block looks truncated - ``self`` and the end of the signature are
    # missing, and the ``SimpleDetector(``, ``TimeDistributed(``, ``BilinearMatrixAttention(``,
    # ``sum([`` and ``Sequential(`` expressions are left unclosed.  Confirm against the
    # complete source before editing the code.
    def __init__(
            vocab: Vocabulary,
            span_encoder: Seq2SeqEncoder,
            reasoning_encoder: Seq2SeqEncoder,
            input_dropout: float = 0.3,
            hidden_dim_maxpool: int = 1024,
            class_embs: bool = True,
            reasoning_use_obj: bool = True,
            reasoning_use_answer: bool = True,
            reasoning_use_question: bool = True,
            pool_reasoning: bool = True,
            pool_answer: bool = True,
            pool_question: bool = False,
            initializer: InitializerApplicator = InitializerApplicator(),
        super(MultiHopAttentionQAFreezeDetRes101, self).__init__(vocab)

        # Object detector, created with ``pretrained=True``; its other arguments are not
        # visible in this view.
        self.detector = SimpleDetector(pretrained=True,

        # freeze everything related to conv net
        for submodule in self.detector.backbone.modules():
            # if isinstance(submodule, BatchNorm2d):
            # submodule.track_running_stats = False
            for p in submodule.parameters():
                p.requires_grad = False

        # Same freezing for the detector's ``after_roi_align`` sub-network.
        for submodule in self.detector.after_roi_align.modules():
            # if isinstance(submodule, BatchNorm2d):
            # submodule.track_running_stats = False
            for p in submodule.parameters():
                p.requires_grad = False

        # Input dropout is only created when ``input_dropout > 0``; otherwise the attribute
        # is None.  NOTE(review): the wrapped dropout module's name is missing in this view.
        self.rnn_input_dropout = TimeDistributed(
                input_dropout)) if input_dropout > 0 else None

        # Both encoders are wrapped in ``TimeDistributed``.
        self.span_encoder = TimeDistributed(span_encoder)
        self.reasoning_encoder = TimeDistributed(reasoning_encoder)

        # NOTE(review): the arguments of both attention modules are missing in this view.
        self.span_attention = BilinearMatrixAttention(

        self.obj_attention = BilinearMatrixAttention(

        self.reasoning_use_obj = reasoning_use_obj
        self.reasoning_use_answer = reasoning_use_answer
        self.reasoning_use_question = reasoning_use_question
        self.pool_reasoning = pool_reasoning
        self.pool_answer = pool_answer
        self.pool_question = pool_question
        # Input width of the final MLP: the sum of the output dims of whichever
        # representations (reasoning / answer / question) have their ``pool_*`` flag set.
        dim = sum([
            d for d, to_pool in [(
                reasoning_encoder.get_output_dim(), self.pool_reasoning
            ), (span_encoder.get_output_dim(), self.pool_answer
                ), (span_encoder.get_output_dim(), self.pool_question)]
            if to_pool

        # Scoring head ending in a single output unit.
        # NOTE(review): the closing parenthesis (and possibly further layers) is missing here.
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(dim, hidden_dim_maxpool),
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(hidden_dim_maxpool, 1),
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
class BertMCQParallel(Model):
    """
    Multiple-choice model that runs BERT over every (choice, premise) input, max-pools the
    pooled BERT outputs over the premises of each choice and scores each choice with a
    single linear layer.

    Design notes kept from the original author:
       get the pooled output from ph batches of size b,size
       flatten the links tokens and pool the output
       reshape to b,n,n-1,size
       reduce to b,n,size (max pool)
       reduce further to b,size (avg, max pool)
       combine two add linear layer to compute the score

    Parameters
    ----------
    vocab : ``Vocabulary``
    bert_model : ``Union[str, BertModel]``
        Either a name/path handed to ``PretrainedBertModel.load`` or an already built model.
    dropout : ``float``, optional (default=0.0)
        Dropout applied to the max-pooled representation before classification.
    trainable : ``bool``, optional (default=True)
        Whether the BERT parameters receive gradients.
    initializer : ``InitializerApplicator``, optional
        NOTE(review): accepted but not applied in the visible code - confirm the intended
        target before wiring it in.
    """

    def ff(self, input_dim, hidden_dim, output_dim):
        """Build two stacked linear layers ``input_dim -> hidden_dim -> output_dim``."""
        return torch.nn.Sequential(torch.nn.Linear(input_dim, hidden_dim),
                                   torch.nn.Linear(hidden_dim, output_dim))

    def __init__(self,
                 vocab: Vocabulary,
                 bert_model: Union[str, BertModel],
                 dropout: float = 0.0,
                 trainable: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        # Sub-modules can only be assigned after the base module has been initialised.
        super().__init__(vocab)

        if isinstance(bert_model, str):
            self.bert_model = PretrainedBertModel.load(bert_model)
        else:
            self.bert_model = bert_model

        # Assigning ``requires_grad`` on the module itself only creates a plain attribute;
        # the flag has to be set on every parameter to actually (un)freeze BERT.
        for parameter in self.bert_model.parameters():
            parameter.requires_grad = trainable
        in_features = self.bert_model.config.hidden_size

        self._dropout = torch.nn.Dropout(p=dropout)

        self._classification_layer = torch.nn.Linear(in_features, 1)

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

    def forward(self, # type: ignore
                tokens: Dict[str, torch.LongTensor],
                token_type_ids: torch.LongTensor,
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
                ) -> Dict[str, torch.Tensor]:
        """
        Score every choice and, when ``label`` is given, compute the loss and update accuracy.

        Returns a dictionary with ``logits`` and ``probs`` of shape ``(batch_size, num_choices)``
        plus ``loss`` when ``label`` is provided.
        """
        debug = False
        # batch_size, num_of_choices, max_premise_perchoice, L
        input_ids = tokens['tokens']
        # same shape as input_ids; id 0 is treated as padding
        input_mask = (input_ids != 0).long()

        # shape: batch_size*num_choices*max_premise_perchoice, max_len
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        flat_attention_mask = input_mask.view(-1, input_mask.size(-1))

        # shape: batch_size*num_choices*max_premise_perchoice, hidden_dim
        _, pooled_ph = self.bert_model(input_ids=flat_input_ids,
                                       token_type_ids=flat_token_type_ids,
                                       attention_mask=flat_attention_mask)

        if debug:
            print(f"input_ids.size() = {input_ids.size()}")
            print(f"token_type_ids.size() = {token_type_ids.size()}")
            print(f"pooled_ph.size() = {pooled_ph.size()}")

        # batch*choice, max_premise_per_choice, hidden_dim
        pooled_ph = pooled_ph.view(-1, input_ids.size(2), pooled_ph.size(-1))

        # max over the premises of each choice -> batch*choice, hidden_dim
        max_pooled_ph, _ = torch.max(pooled_ph, dim=1, keepdim=False)

        if debug:
            print(f"max_pooled_ph.size() = {max_pooled_ph.size()}")

        # Dropout must not depend on the debug flag; ``torch.nn.Dropout`` is already a no-op
        # in eval mode.
        max_pooled_ph = self._dropout(max_pooled_ph)

        # apply classification layer
        logits = self._classification_layer(max_pooled_ph)

        # shape: batch_size,num_choices
        reshaped_logits = logits.view(-1, input_ids.size(1))
        if debug:
            print(f"reshaped_logits = {reshaped_logits}")

        probs = torch.nn.functional.softmax(reshaped_logits, dim=-1)

        output_dict = {"logits": reshaped_logits, "probs": probs}

        if label is not None:
            loss = self._loss(reshaped_logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(reshaped_logits, label)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running categorical accuracy under the ``accuracy`` key."""
        return {
                'accuracy': self._accuracy.get_metric(reset),
        }
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        similarity_function: SimilarityFunction,
        modeling_layer: Seq2SeqEncoder,
        span_end_encoder: Seq2SeqEncoder,
        dropout: float = 0.2,
        mask_lstms: bool = True,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        """
        Wire up the embedding, highway, phrase, attention, modeling and span-prediction layers.

        Raises whatever ``check_dimensions_match`` raises when the configured layer sizes do
        not fit together.
        """
        # Local import keeps this block self-contained even if the module header lacks it.
        from allennlp.common.checks import check_dimensions_match

        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(
            modeling_layer.get_input_dim(),
            4 * encoding_dim,
            "modeling layer input dim",
            "4 * encoding dim",
        )
        check_dimensions_match(
            text_field_embedder.get_output_dim(),
            phrase_layer.get_input_dim(),
            "text field embedder output dim",
            "phrase layer input dim",
        )
        check_dimensions_match(
            span_end_encoder.get_input_dim(),
            4 * encoding_dim + 3 * modeling_dim,
            "span end encoder input dim",
            "4 * encoding dim + 3 * modeling dim",
        )

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so callers can always invoke ``self._dropout(x)``.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        # Apply the configured initializer once every sub-module has been registered.
        initializer(self)

class SpanConstituencyParser(Model):
    """
    This ``SpanConstituencyParser`` simply encodes a sequence of text
    with a stacked ``Seq2SeqEncoder``, extracts span representations using a
    ``SpanExtractor``, and then predicts a label for each span in the sequence.
    These labels are non-terminal nodes in a constituency parse tree, which we then
    greedily reconstruct.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    span_extractor : ``SpanExtractor``, required.
        The method used to extract the spans from the encoded sequence.
    encoder : ``Seq2SeqEncoder``, required.
        The encoder that we will use in between embedding tokens and
        generating span representations.
    feedforward_layer : ``FeedForward``, required.
        The FeedForward layer that we will use in between the encoder and the linear
        projection to a distribution over span labels.
    pos_tag_embedding : ``Embedding``, optional.
        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 span_extractor: SpanExtractor,
                 encoder: Seq2SeqEncoder,
                 feedforward_layer: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 evalb_directory_path: str = None) -> None:
        """
        Store the sub-modules, build the span-label projection and validate layer sizes.

        ``evalb_directory_path`` enables the EVALB bracketing scorer; when it is ``None`` no
        EVALB score is computed.
        """
        # Local import keeps this block self-contained even if the module header lacks it.
        from allennlp.common.checks import check_dimensions_match

        super(SpanConstituencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.span_extractor = span_extractor
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.encoder = encoder
        self.feedforward_layer = TimeDistributed(feedforward_layer) if feedforward_layer else None
        self.pos_tag_embedding = pos_tag_embedding or None
        # The projection consumes the feedforward output when one is configured, otherwise
        # the raw span representations.
        if feedforward_layer is not None:
            output_dim = feedforward_layer.get_output_dim()
        else:
            output_dim = span_extractor.get_output_dim()

        self.tag_projection_layer = TimeDistributed(Linear(output_dim, self.num_classes))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        check_dimensions_match(representation_dim,
                               encoder.get_input_dim(),
                               "representation dim (tokens + optional POS tags)",
                               "encoder input dim")
        check_dimensions_match(encoder.get_output_dim(),
                               span_extractor.get_input_dim(),
                               "encoder input dim",
                               "span extractor input dim")
        if feedforward_layer is not None:
            check_dimensions_match(span_extractor.get_output_dim(),
                                   feedforward_layer.get_input_dim(),
                                   "span extractor output dim",
                                   "feedforward input dim")

        self.tag_accuracy = CategoricalAccuracy()

        if evalb_directory_path is not None:
            self._evalb_score = EvalbBracketingScorer(evalb_directory_path)
        else:
            self._evalb_score = None

        # Apply the configured initializer once every sub-module has been registered.
        initializer(self)

    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                pos_tags: Dict[str, torch.LongTensor] = None,
                span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.

        Raises
        ------
        ConfigurationError
            If the model has a POS tag embedding but no ``pos_tags`` were passed.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            # Append the POS embedding to every token representation.
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
                                              for tree in batch_gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   output_dict["tokens"],
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Constructs an NLTK ``Tree`` given the scored spans. We also switch to exclusive
        span ends when constructing the tree representation, because it makes indexing
        into lists cleaner for ranges of text, rather than individual indices.

        Finally, for batch prediction, we will have padded spans and class probabilities.
        In order to make this less confusing, we remove all the padded spans and
        distributions from ``spans`` and ``class_probabilities`` respectively.

        The same ``output_dict`` is updated in place (``spans`` and ``class_probabilities``
        become per-example lists, ``trees`` is added) and returned.
        """
        all_predictions = output_dict['class_probabilities'].cpu().data
        all_spans = output_dict["spans"].cpu().data

        all_sentences = output_dict["tokens"]
        # POS tags are only used when every batch element provides them.
        all_pos_tags = output_dict["pos_tags"] if all(output_dict["pos_tags"]) else None
        num_spans = output_dict["num_spans"].data
        trees = self.construct_trees(all_predictions, all_spans, num_spans, all_sentences, all_pos_tags)

        # Strip the padding: keep only the first ``num_spans[i]`` entries of each example.
        batch_size = all_predictions.size(0)
        output_dict["spans"] = [all_spans[i, :num_spans[i]] for i in range(batch_size)]
        output_dict["class_probabilities"] = [all_predictions[i, :num_spans[i], :] for i in range(batch_size)]

        output_dict["trees"] = trees
        return output_dict

    def construct_trees(self,
                        predictions: torch.FloatTensor,
                        all_spans: torch.LongTensor,
                        num_spans: torch.LongTensor,
                        sentences: List[List[str]],
                        pos_tags: List[List[str]] = None) -> List[Tree]:
        """
        Construct ``nltk.Tree``'s for each batch element by greedily nesting spans.
        The trees use exclusive end indices, which contrasts with how spans are
        represented in the rest of the model.

        Parameters
        ----------
        predictions : ``torch.FloatTensor``, required.
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        all_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the span
            indices we scored.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        sentences : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, optional (default = None).
            A list of POS tags for each word in the sentence for each element
            in the batch.

        Returns
        -------
        A ``List[Tree]`` containing the decoded trees for each element in the batch.
        """
        # Switch to using exclusive end spans.
        exclusive_end_spans = all_spans.clone()
        exclusive_end_spans[:, :, -1] += 1
        no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

        trees: List[Tree] = []
        for batch_index, (scored_spans, spans, sentence) in enumerate(zip(predictions,
                                                                          exclusive_end_spans,
                                                                          sentences)):
            selected_spans = []
            # Only the first ``num_spans[batch_index]`` entries are real; the rest is padding.
            for prediction, span in zip(scored_spans[:num_spans[batch_index]],
                                        spans[:num_spans[batch_index]]):
                start, end = span
                no_label_prob = prediction[no_label_id]
                label_prob, label_index = torch.max(prediction, -1)

                # Does the span have a label != NO-LABEL or is it the root node?
                # If so, include it in the spans that we consider.
                if int(label_index) != no_label_id or (start == 0 and end == len(sentence)):
                    # TODO(Mark): Remove this once pylint sorts out named tuples.
                    selected_spans.append(SpanInformation(start=int(start), # pylint: disable=no-value-for-parameter
                                                          end=int(end),
                                                          label_prob=float(label_prob),
                                                          no_label_prob=float(no_label_prob),
                                                          label_index=int(label_index)))

            # The spans we've selected might overlap, which causes problems when we try
            # to construct the tree as they won't nest properly.
            consistent_spans = self.resolve_overlap_conflicts_greedily(selected_spans)

            spans_to_labels = {(span.start, span.end):
                               self.vocab.get_token_from_index(span.label_index, "labels")
                               for span in consistent_spans}
            sentence_pos = pos_tags[batch_index] if pos_tags is not None else None
            trees.append(self.construct_tree_from_spans(spans_to_labels, sentence, sentence_pos))

        return trees

    @staticmethod
    def resolve_overlap_conflicts_greedily(spans: List[SpanInformation]) -> List[SpanInformation]:
        """
        Given a set of spans, removes spans which overlap by evaluating the difference
        in probability between one being labeled and the other explicitly having no label
        and vice-versa. The worst case time complexity of this method is ``O(k * n^4)`` where ``n``
        is the length of the sentence that the spans were enumerated from (and therefore
        ``k * m^2`` complexity with respect to the number of spans ``m``) and ``k`` is the
        number of conflicts. However, in practice, there are very few conflicts. Hopefully.

        This function modifies ``spans`` to remove overlapping spans.

        Parameters
        ----------
        spans: ``List[SpanInformation]``, required.
            A list of spans, where each span is a ``namedtuple`` containing the
            following attributes:

        start : ``int``
            The start index of the span.
        end : ``int``
            The exclusive end index of the span.
        no_label_prob : ``float``
            The probability of this span being assigned the ``NO-LABEL`` label.
        label_prob : ``float``
            The probability of the most likely label.

        Returns
        -------
        A modified list of ``spans``, with the conflicts resolved by considering local
        differences between pairs of spans and removing one of the two spans.
        """
        conflicts_exist = True
        while conflicts_exist:
            conflicts_exist = False
            for span1_index, span1 in enumerate(spans):
                for span2_index, span2 in list(enumerate(spans))[span1_index + 1:]:
                    if (span1.start < span2.start < span1.end < span2.end or
                                span2.start < span1.start < span2.end < span1.end):
                        # The spans cross (they overlap without one nesting inside the other).
                        conflicts_exist = True
                        # Compare the two ways of resolving the conflict: "span1 unlabeled and
                        # span2 labeled" versus "span1 labeled and span2 unlabeled". When the
                        # former is the less likely reading we drop span2, otherwise span1.
                        if (span1.no_label_prob + span2.label_prob <
                                    span2.no_label_prob + span1.label_prob):
                            spans.pop(span2_index)
                        else:
                            spans.pop(span1_index)
                        break
                if conflicts_exist:
                    # ``spans`` was just mutated, so the indices held by the loops are
                    # stale; restart the scan from the beginning.
                    break
        return spans

    @staticmethod
    def construct_tree_from_spans(spans_to_labels: Dict[Tuple[int, int], str],
                                  sentence: List[str],
                                  pos_tags: List[str] = None) -> Tree:
        """
        Builds a single tree covering the whole sentence from a set of labelled spans.

        Parameters
        ----------
        spans_to_labels : ``Dict[Tuple[int, int], str]``, required.
            A mapping from spans to constituency labels.
        sentence : ``List[str]``, required.
            A list of tokens forming the sentence to be parsed.
        pos_tags : ``List[str]``, optional (default = None)
            A list of the pos tags for the words in the sentence, if they
            were either predicted or taken as input to the model.

        Returns
        -------
        An ``nltk.Tree`` constructed from the labelled spans.
        """
        def assemble_subtree(start: int, end: int):
            # Returns a list of sibling trees covering the tokens in ``[start, end)``.
            if (start, end) in spans_to_labels:
                # Some labels contain nested spans, e.g S-VP.
                # We actually want to create (S (VP ...)) nodes
                # for these labels, so we split them up here.
                labels: List[str] = spans_to_labels[(start, end)].split("-")
            else:
                labels = None

            # This node is a leaf.
            if end - start == 1:
                word = sentence[start]
                pos_tag = pos_tags[start] if pos_tags is not None else "XX"
                tree = Tree(pos_tag, [word])
                if labels is not None and pos_tags is not None:
                    # If POS tags were passed explicitly,
                    # they are added as pre-terminal nodes.
                    while labels:
                        tree = Tree(labels.pop(), [tree])
                elif labels is not None:
                    # Otherwise, we didn't want POS tags
                    # at all.
                    tree = Tree(labels.pop(), [word])
                    while labels:
                        tree = Tree(labels.pop(), [tree])
                return [tree]

            argmax_split = start + 1
            # Find the next largest subspan such that
            # the left hand side is a constituent.
            for split in range(end - 1, start, -1):
                if (start, split) in spans_to_labels:
                    argmax_split = split
                    # Candidates are visited from largest to smallest, so the first
                    # hit is the largest one; stop before a smaller one replaces it.
                    break

            left_trees = assemble_subtree(start, argmax_split)
            right_trees = assemble_subtree(argmax_split, end)
            children = left_trees + right_trees
            if labels is not None:
                # Wrap the children in one node per label part, innermost label first.
                while labels:
                    children = [Tree(labels.pop(), children)]
            return children

        tree = assemble_subtree(0, len(sentence))
        return tree[0]

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collects the current values of the model's metrics.

        Parameters
        ----------
        reset : ``bool``, optional (default = False)
            Forwarded to every underlying metric's ``get_metric`` call.

        Returns
        -------
        A dictionary holding ``"tag_accuracy"`` plus, when ``self._evalb_score`` is
        not ``None``, every entry returned by that metric.
        """
        all_metrics = {}
        all_metrics["tag_accuracy"] = self.tag_accuracy.get_metric(reset=reset)
        if self._evalb_score is not None:
            evalb_metrics = self._evalb_score.get_metric(reset=reset)
            # Merge the EVALB scores into the result instead of discarding them.
            all_metrics.update(evalb_metrics)
        return all_metrics

    def from_params(cls, vocab: Vocabulary, params: Params) -> 'SpanConstituencyParser':
        # Alternate constructor: pops the configuration of each sub-component out of
        # ``params`` and builds that component through its own ``from_params``.
        # NOTE(review): ``cls`` is the first parameter but no ``@classmethod`` decorator
        # is visible above this definition - confirm it was not lost.
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
        span_extractor = SpanExtractor.from_params(params.pop("span_extractor"))
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))

        # The feedforward layer is optional: a missing "feedforward" key yields ``None``.
        feed_forward_params = params.pop("feedforward", None)
        if feed_forward_params is not None:
            feedforward_layer = FeedForward.from_params(feed_forward_params)
            # NOTE(review): an ``else:`` appears to be missing before the next line; as
            # written the layer that was just built is immediately replaced by ``None``
            # and the name is unbound when "feedforward" is absent.
            feedforward_layer = None
        pos_tag_embedding_params = params.pop("pos_tag_embedding", None)
        if pos_tag_embedding_params is not None:
            pos_tag_embedding = Embedding.from_params(vocab, pos_tag_embedding_params)
            # NOTE(review): same pattern as above - an ``else:`` looks to be missing here.
            pos_tag_embedding = None
        initializer = InitializerApplicator.from_params(params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(params.pop('regularizer', []))
        evalb_directory_path = params.pop("evalb_directory_path", None)

        # NOTE(review): the ``cls(...)`` call below is cut off after its first keyword
        # argument; the remaining arguments and the closing parenthesis are not present.
        return cls(vocab=vocab,
class AclGloveClassifier(Model):
    """
    Classifies a paper from its embedded abstract and, optionally, an embedding of its id.

    The embedded ``abstract`` is averaged over dimension 1. When ``use_node_vector`` is set,
    the output of ``node_embedder(paper_id)`` is concatenated to that average along the last
    dimension. The result is passed through dropout and ``classifier_feedforward`` to obtain
    the label logits, which are trained with a cross-entropy loss.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Its ``"labels"`` namespace determines the number of classes and their names.
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``abstract`` input.
    node_embedder : ``TokenEmbedder``
        Applied to ``paper_id`` when ``use_node_vector`` is ``True``.
    verbose_metrics : ``bool``
        If ``True``, ``get_metrics`` also reports precision, recall and F1 for every label.
    classifier_feedforward : ``FeedForward``
        Maps the (optionally concatenated) embedding to the label logits.
    use_node_vector : ``bool``, optional (default = True)
        Whether to concatenate the node embedding to the abstract embedding.
    use_abstract : ``bool``, optional (default = True)
        Stored on the model; it is not read anywhere else in this class.
    dropout : ``float``, optional (default = 0.2)
        Dropout probability applied before ``classifier_feedforward``.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Accepted for configuration compatibility; this class does not apply it.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        Passed to the ``Model`` base class.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 node_embedder: TokenEmbedder,
                 verbose_metrics: bool,
                 classifier_feedforward: FeedForward,
                 use_node_vector: bool = True,
                 use_abstract: bool = True,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(AclGloveClassifier, self).__init__(vocab, regularizer)

        self.node_embedder = node_embedder
        self.text_field_embedder = text_field_embedder
        self.use_node_vector = use_node_vector
        self.use_abstract = use_abstract
        self.dropout = torch.nn.Dropout(dropout)
        self.num_classes = self.vocab.get_vocab_size("labels")

        self.classifier_feedforward = classifier_feedforward

        self.label_accuracy = CategoricalAccuracy()
        self.verbose_metrics = verbose_metrics

        # One F1 metric per label, keyed by the label's name in the vocabulary.
        self.label_f1_metrics = {}
        for i in range(self.num_classes):
            label_name = vocab.get_token_from_index(index=i, namespace="labels")
            self.label_f1_metrics[label_name] = F1Measure(positive_label=i)

        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self,
                abstract: Dict[str, torch.LongTensor],
                paper_id: torch.LongTensor,
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Computes the label logits and, when ``label`` is given, the loss and metrics.

        Parameters
        ----------
        abstract : ``Dict[str, torch.LongTensor]``
            Input for ``text_field_embedder``.
        paper_id : ``torch.LongTensor``
            Input for ``node_embedder``; only used when ``use_node_vector`` is ``True``.
        label : ``torch.LongTensor``, optional (default = None)
            Gold labels. When absent no loss is computed and no metric is updated.

        Returns
        -------
        A dictionary with ``"logits"`` and, when ``label`` is given, ``"label"`` and ``"loss"``.
        """
        embedding = torch.mean(self.text_field_embedder(abstract), dim=1)
        if self.use_node_vector:
            embedding = torch.cat([embedding,
                                   self.node_embedder(paper_id)], dim=-1)

        logits = self.classifier_feedforward(self.dropout(embedding))
        class_probs = F.softmax(logits, dim=1)
        output_dict = {"logits": logits}

        if label is not None:
            loss = self.loss(logits, label)
            output_dict["label"] = label
            output_dict["loss"] = loss
            # Metrics need gold labels, so they are only updated when labels exist.
            for metric in self.label_f1_metrics.values():
                metric(class_probs, label)
            self.label_accuracy(logits, label)

        return output_dict

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Adds the predicted label name under ``"pred_label"`` and, if ``"label"`` is present,
        replaces the gold label indices with their names. Returns the same dictionary.
        """
        class_probs = F.softmax(output_dict['logits'], dim=-1)
        output_dict['pred_label'] = [
            self.vocab.get_token_from_index(index=int(np.argmax(probs)), namespace="labels")
            for probs in class_probs.cpu()
        ]
        # ``forward`` only stores "label" when gold labels were supplied.
        if 'label' in output_dict:
            output_dict['label'] = [
                self.vocab.get_token_from_index(index=int(label), namespace="labels")
                for label in output_dict['label'].cpu()
            ]
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns ``"average_F1"`` (unweighted mean of the per-label F1 scores, ``0.0`` when
        there are no labels) and ``"accuracy"``; with ``verbose_metrics`` it also returns
        ``<label>_P``, ``<label>_R`` and ``<label>_F1`` for every label.
        """
        metric_dict = {}

        sum_f1 = 0.0
        for name, metric in self.label_f1_metrics.items():
            metric_val = metric.get_metric(reset)
            if self.verbose_metrics:
                metric_dict[name + '_P'] = metric_val[0]
                metric_dict[name + '_R'] = metric_val[1]
                metric_dict[name + '_F1'] = metric_val[2]
            sum_f1 += metric_val[2]

        total_len = len(self.label_f1_metrics)
        if total_len > 0:
            average_f1 = sum_f1 / total_len
        else:
            average_f1 = 0.0
        metric_dict['average_F1'] = average_f1
        metric_dict['accuracy'] = self.label_accuracy.get_metric(reset)
        return metric_dict
 def test_does_not_divide_by_zero_with_no_count(self):
     """A freshly created metric that has seen no predictions must report 0.0."""
     untouched_metric_value = CategoricalAccuracy().get_metric()
     self.assertAlmostEqual(untouched_metric_value, 0.0)
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 question_encoder: Optional[Seq2SeqEncoder] = None,
                 choice_encoder: Optional[Seq2SeqEncoder] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 aggregate_question: Optional[str] = "max",
                 aggregate_choice: Optional[str] = "max",
                 embeddings_dropout_value: Optional[float] = 0.0,
                 share_encoders: Optional[bool] = False,
                 choices_init_from_question_states: Optional[bool] = False,
                 use_choice_sum_instead_of_question: Optional[bool] = False,
                 params=Params) -> None:
        # Stores the embedder, the optional encoders and the aggregation settings, checks
        # that question and choice representations have the same size, and sets up the
        # question-to-choice attention, the accuracy metric and the cross-entropy loss.
        # NOTE(review): the default for ``params`` is the ``Params`` class itself, not an
        # instance; ``params.get(...)`` below would fail with that default - confirm intent.
        # NOTE(review): ``initializer`` and ``share_encoders`` are not used in the lines
        # that are visible here.
        super(QAMultiChoice_OneVsRest_Choices_v1, self).__init__(vocab)

        # TO DO: AllenNLP does not support statefull RNNS yet..
        init_is_supported = False
        if not init_is_supported and (choices_init_from_question_states):
            raise ValueError(
                "choices_init_from_question_states=True or facts_init_from_question_states=True are not supported yet!")
            # NOTE(review): the assignment below sits after the ``raise`` inside the same
            # branch and is therefore unreachable; an ``else:`` appears to be missing.
            self._choices_init_from_question_states = choices_init_from_question_states

        self._use_cuda = (torch.cuda.is_available() and torch.cuda.current_device() >= 0)

        self._return_question_to_choices_att = False
        self._use_choice_sum_instead_of_question = use_choice_sum_instead_of_question

        self._params = params

        self._text_field_embedder = text_field_embedder
        if embeddings_dropout_value > 0.0:
            self._embeddings_dropout = torch.nn.Dropout(p=embeddings_dropout_value)
            # NOTE(review): an ``else:`` appears to be missing; as written the dropout
            # module is immediately replaced by the identity function.
            self._embeddings_dropout = lambda x: x

        self._question_encoder = question_encoder

        # choices encoding
        self._choice_encoder = choice_encoder

        self._question_aggregate = aggregate_question
        self._choice_aggregate = aggregate_choice

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # Without an encoder the representation size is that of the embedder.
        question_output_dim = self._text_field_embedder.get_output_dim()
        if self._question_encoder is not None:
            question_output_dim = self._question_encoder.get_output_dim()

        choice_output_dim = self._text_field_embedder.get_output_dim()
        if self._choice_encoder is not None:
            choice_output_dim = self._choice_encoder.get_output_dim()

        if question_output_dim != choice_output_dim:
            # NOTE(review): the message below is cut off - the arguments for the two
            # ``{}`` placeholders and the closing parenthesis are missing.
            raise ConfigurationError("Output dimension of the question_encoder (dim: {}), "
                                     "plus choice_encoder (dim: {})"
                                     "must match! "

        # question to choice attention
        att_question_to_choice_params = params.get("att_question_to_choice")
        if "tensor_1_dim" in att_question_to_choice_params:
            att_question_to_choice_params = update_params(att_question_to_choice_params,
                                                          {"tensor_1_dim": question_output_dim,
                                                           "tensor_2_dim": choice_output_dim})
        # NOTE(review): the constructor call below is cut off after its opening parenthesis.
        self._matrix_attention_question_to_choice = LegacyMatrixAttention(

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

# Score every passage position as a candidate span end and normalise over unmasked positions.
span_end_logits = span_end_predictor(span_end_input).squeeze(-1)
span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

print("-------------- LOGITS OF BOTH SPANS and BEST SPAN ---------------")

# Overwrite the logits at masked positions with -1e7 before searching for the best span.
span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)

best_span = BidirectionalAttentionFlow_1.get_best_span(span_start_logits, span_end_logits)

print("best_spans", best_span)

# ------------------------------ GET LOSSES AND ACCURACIES -----------------------------------
span_start_accuracy_function = CategoricalAccuracy()
span_end_accuracy_function = CategoricalAccuracy()
span_accuracy_function = BooleanAccuracy()
squad_metrics_function = SquadEmAndF1()

# Compute the loss for training.
if span_start is not None:
    # The total loss is the sum of the negative log-likelihoods of the gold start and end.
    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss
    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))

    span_start_accuracy = span_start_accuracy_function.get_metric()
class QAMultiChoice_OneVsRest_Choices_v1(Model):
   This ``QAMultiChoice_OneVsRest_Choices_v1`` can have different modes:

   1. Question to Choice:
        If `use_choice_sum_instead_of_question`==False then this is a classical BiLSTM maxout model that models the
        interaction between Question and Choice representation.
   2. Choice-To-Choice: If use_choice_sum_instead_of_question==True.
        In this case the `question` representation is replaced with the average of all choices representations.
   3. Choice only: If the function "att_question_to_choice" has setting
      "type": "linear_extended",
      "combination": "y",
      The `combination` for the interaction is set to `y` which means that only the choice is used to predict the answer.

   In more detail, the model works in the following way:
    1. Obtain a BiLSTM context representation of the token sequences of the
    `question` and each `choice`.
    2. Get an aggregated (single vector) representations for `question` and `choice` using element-wise `max` operation.
        If use_choice_sum_instead_of_question== True, then `question` == avg(`choice1`, `choice2`.. `choiceN`)
    3. Compute the attention score between `question` and `choice` as  `linear_layer([u, v, u - v, u * v])`,
    where `u` and `v` are the representations from Step 2.
        Here, we can change the attention function from `linear_layer([u, v, u - v, u * v])` to simply `linear_layer(y)`,
        which means that we will use only the choice for the final prediction!
    4. Select as answer the `choice` with the highest attention with the `question`.

    Pseudo-code looks like:

    question_encoded = context_enc(question_words)  # context_enc can be any AllenNLP supported or None. Bi-directional LSTM is used
    choice_encoded = context_enc(choice_words)

    question_aggregate = aggregate_method(question_encoded) # aggregate_method can be max, min, avg. ``max`` is used.
    choice_aggregate = aggregate_method(choice_encoded)

    If use_choice_sum_instead_of_question==True:
        # In this case we have choice-to-choices interaction
        question_aggregate = (choice1_aggregate + choice2_aggregate + choice3_aggregate + choice4_aggregate) / 4

    If att_question_to_choice.combination=="x,y,x-y,x*y":
        inter = concat([question_aggregate, choice_aggregate, choice_aggregate - question_aggregate, question_aggregate
         * choice_aggregate])
    elif att_question_to_choice.combination=="y":
        # In this case we have choice-only interaction
        inter = choice_aggregate

    choice_to_question_att = linear_layer(inter) # the output is a scalar value for each question-to-choice interaction

    # The choice_to_question_att of the four choices are normalized using ``softmax``
    # and the choice with the highest attention is selected as the answer.

    The model is inspired by the BiLSTM Max-Out model from Conneau, A. et al. (2017) ‘Supervised Learning of
    Universal Sentence Representations from Natural Language Inference Data’.

    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``choice`` ``TextFields`` we get as input to the
    aggregate_feedforward : ``FeedForward``
        These feedforward networks are applied to the concatenated result of the
        encoder networks, and its output is used as the entailment class logits.
    question_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        After embedding the question, we can optionally apply an encoder.  If this is ``None``, we
        will do nothing.
    choice_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        After embedding the choice, we can optionally apply an encoder.  If this is ``None``,
        we will use the ``question_encoder`` for the encoding (doing nothing if ``question_encoder``
        is also ``None``).
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    share_encoders : ``bool``, optional (default=``false``)
        Shares the weights of the question and choice encoders.
    aggregate_question : ``str``, optional (default=``max``, allowed options [max, avg, sum, last])
        The aggregation method for the encoded question.
    aggregate_choice : ``str``, optional (default=``max``, allowed options [max, avg, sum, last])
        The aggregation method for the encoded choice.

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 question_encoder: Optional[Seq2SeqEncoder] = None,
                 choice_encoder: Optional[Seq2SeqEncoder] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 aggregate_question: Optional[str] = "max",
                 aggregate_choice: Optional[str] = "max",
                 embeddings_dropout_value: Optional[float] = 0.0,
                 share_encoders: Optional[bool] = False,
                 choices_init_from_question_states: Optional[bool] = False,
                 use_choice_sum_instead_of_question: Optional[bool] = False,
                 params=Params) -> None:
        # Stores the embedder, the optional encoders and the aggregation settings, checks
        # that question and choice representations have the same size, and sets up the
        # question-to-choice attention, the accuracy metric and the cross-entropy loss.
        # NOTE(review): the default for ``params`` is the ``Params`` class itself, not an
        # instance; ``params.get(...)`` below would fail with that default - confirm intent.
        # NOTE(review): ``initializer`` and ``share_encoders`` are not used in the lines
        # that are visible here.
        super(QAMultiChoice_OneVsRest_Choices_v1, self).__init__(vocab)

        # TO DO: AllenNLP does not support statefull RNNS yet..
        init_is_supported = False
        if not init_is_supported and (choices_init_from_question_states):
            raise ValueError(
                "choices_init_from_question_states=True or facts_init_from_question_states=True are not supported yet!")
            # NOTE(review): the assignment below sits after the ``raise`` inside the same
            # branch and is therefore unreachable; an ``else:`` appears to be missing.
            self._choices_init_from_question_states = choices_init_from_question_states

        self._use_cuda = (torch.cuda.is_available() and torch.cuda.current_device() >= 0)

        self._return_question_to_choices_att = False
        self._use_choice_sum_instead_of_question = use_choice_sum_instead_of_question

        self._params = params

        self._text_field_embedder = text_field_embedder
        if embeddings_dropout_value > 0.0:
            self._embeddings_dropout = torch.nn.Dropout(p=embeddings_dropout_value)
            # NOTE(review): an ``else:`` appears to be missing; as written the dropout
            # module is immediately replaced by the identity function.
            self._embeddings_dropout = lambda x: x

        self._question_encoder = question_encoder

        # choices encoding
        self._choice_encoder = choice_encoder

        self._question_aggregate = aggregate_question
        self._choice_aggregate = aggregate_choice

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # Without an encoder the representation size is that of the embedder.
        question_output_dim = self._text_field_embedder.get_output_dim()
        if self._question_encoder is not None:
            question_output_dim = self._question_encoder.get_output_dim()

        choice_output_dim = self._text_field_embedder.get_output_dim()
        if self._choice_encoder is not None:
            choice_output_dim = self._choice_encoder.get_output_dim()

        if question_output_dim != choice_output_dim:
            # NOTE(review): the message below is cut off - the arguments for the two
            # ``{}`` placeholders and the closing parenthesis are missing.
            raise ConfigurationError("Output dimension of the question_encoder (dim: {}), "
                                     "plus choice_encoder (dim: {})"
                                     "must match! "

        # question to choice attention
        att_question_to_choice_params = params.get("att_question_to_choice")
        if "tensor_1_dim" in att_question_to_choice_params:
            att_question_to_choice_params = update_params(att_question_to_choice_params,
                                                          {"tensor_1_dim": question_output_dim,
                                                           "tensor_2_dim": choice_output_dim})
        # NOTE(review): the constructor call below is cut off after its opening parenthesis.
        self._matrix_attention_question_to_choice = LegacyMatrixAttention(

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()


    # properties

    @property
    def return_question_to_choices_att(self):
        """Whether ``forward`` adds the question-to-choice attentions to its output."""
        return self._return_question_to_choices_att

    @return_question_to_choices_att.setter
    def return_question_to_choices_att(self, value: bool):
        """
        This makes the model return question to choice attentions.

        :param value: ``True`` to have ``forward`` export the attentions under the
            ``"attentions"`` key of its output dictionary, ``False`` to omit them.
        :return: nothing
        """
        self._return_question_to_choices_att = value

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                choices_list: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # NOTE(review): the prose lines below read like the body of a docstring whose
        # opening and closing triple quotes are missing; as written they are not valid code.
        question : Dict[str, torch.LongTensor]
            From a ``TextField``
        choices_list : Dict[str, torch.LongTensor]
            From a ``List[TextField]``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        # NOTE(review): this helper call and the one for the question below look like they
        # lost intermediate arguments - verify against the helpers' signatures.
        encoded_choices_aggregated = embedd_encode_and_aggregate_list_text_field(choices_list,
                                                                                 self._choice_aggregate)  # # bs, choices, hs

        # Question-to-choice mode: score each aggregated choice against the aggregated question.
        if not self._use_choice_sum_instead_of_question:
            encoded_question_aggregated, _ = embedd_encode_and_aggregate_text_field(question, self._text_field_embedder,
                                                                                    get_last_states=False)  # bs, hs

            # NOTE(review): the call below is cut off after its first argument.
            q_to_choices_att = self._matrix_attention_question_to_choice(encoded_question_aggregated.unsqueeze(1),

            label_logits = q_to_choices_att
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
            # NOTE(review): an ``else:`` appears to be missing here; the lines that follow
            # recompute ``label_logits`` from choice-to-choice attention and presumably
            # belong to the ``use_choice_sum_instead_of_question`` branch - confirm.
            bs = encoded_choices_aggregated.shape[0]
            choices_cnt = encoded_choices_aggregated.shape[1]

            ch_to_choices_att = self._matrix_attention_question_to_choice(encoded_choices_aggregated,
                                                                          encoded_choices_aggregated)  # bs, ch, ch

            # Zero the diagonal so a choice's attention to itself does not enter the sum.
            idx = torch.arange(0, choices_cnt, out=torch.cuda.LongTensor() if self._use_cuda else torch.LongTensor())
            ch_to_choices_att[:, idx, idx] = 0.0

            q_to_choices_att = torch.sum(ch_to_choices_att, dim=1)

            label_logits = q_to_choices_att
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        # Optionally export the attentions alongside the predictions.
        if self._return_question_to_choices_att:
            attentions_dict = {}
            know_interactions_weights_dict = {}
            if self._return_question_to_choices_att:
                # Keep also the interaction weights used for the final prediction

                # attentions
                att_to_export_q_to_ch = {}
                q_to_ch_raw_type = "__".join(["ctx", "ctx"])

                if q_to_ch_raw_type not in know_interactions_weights_dict:
                    know_interactions_weights_dict[q_to_ch_raw_type] = 1.0

                if not q_to_ch_raw_type in att_to_export_q_to_ch:
                    # NOTE(review): the call on the next line is cut off after its opening
                    # parenthesis, and the two assignments marked below have no right-hand side.
                    q_to_ch_att_ctx_ctx = self._matrix_attention_question_to_choice(
                    q_to_ch_att_ctx_ctx = torch.nn.functional.softmax(q_to_ch_att_ctx_ctx, dim=-1)
                    att_to_export_q_to_ch[q_to_ch_raw_type] =

                # NOTE(review): right-hand side missing.
                att_to_export_q_to_ch["final"] =
                attentions_dict["att_q_to_ch"] = att_to_export_q_to_ch

            output_dict["attentions"] = attentions_dict

        # With gold labels available, compute the loss and update the accuracy metric.
        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns the model's accuracy under the ``'accuracy'`` key.

        Parameters
        ----------
        reset : ``bool``, optional (default = False)
            Forwarded to the accuracy metric's ``get_metric`` call.
        """
        return {
            'accuracy': self._accuracy.get_metric(reset),
        }

    def from_params(cls, vocab: Vocabulary, params: Params) -> 'QAMultiChoice_OneVsRest_Choices_v1':
        # Alternate constructor: reads the embedder, encoder, aggregation and initializer
        # settings out of ``params`` and builds the corresponding components.
        # NOTE(review): ``cls`` is the first parameter but no ``@classmethod`` decorator
        # is visible above this definition - confirm it was not lost.
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = BasicTextFieldEmbedder.from_params(vocab, embedder_params)

        embeddings_dropout_value = params.pop("embeddings_dropout", 0.0)

        # question encoder
        question_encoder_params = params.pop("question_encoder", None)
        question_enc_aggregate = params.pop("question_encoder_aggregate", "max")
        share_encoders = params.pop("share_encoders", False)

        # condition the choices or facts encoding on question output states
        choices_init_from_question_states = params.pop("choices_init_from_question_states", False)

        if question_encoder_params is not None:
            question_encoder = Seq2SeqEncoder.from_params(question_encoder_params)
            # NOTE(review): an ``else:`` appears to be missing; as written the encoder that
            # was just built is immediately replaced by ``None``.
            question_encoder = None

        if share_encoders:
            # Sharing means the choice side reuses the question encoder and aggregation.
            choice_encoder = question_encoder
            choice_enc_aggregate = question_enc_aggregate
            # NOTE(review): an ``else:`` appears to be missing before the "choice encoder"
            # section; as written the shared settings are overwritten right away.
            # choice encoder
            choice_encoder_params = params.pop("choice_encoder", None)
            choice_enc_aggregate = params.pop("choice_encoder_aggregate", "max")

            if choice_encoder_params is not None:
                choice_encoder = Seq2SeqEncoder.from_params(choice_encoder_params)
                # NOTE(review): same pattern - an ``else:`` looks to be missing here.
                choice_encoder = None

        # ``get`` (not ``pop``) leaves this key in ``params``.
        use_choice_sum_instead_of_question = params.get("use_choice_sum_instead_of_question", False)
        init_params = params.pop('initializer', None)
        initializer = (InitializerApplicator.from_params(init_params)
                       if init_params is not None
                       else InitializerApplicator())

        # NOTE(review): the ``cls(...)`` call below is cut off after its first keyword
        # argument; the remaining arguments and the closing parenthesis are not present.
        return cls(vocab=vocab,
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 num_linking_features: int = 0,
                 num_entity_bits: int = 0,
                 entity_bits_output: bool = True,
                 use_entities: bool = False,
                 denotation_only: bool = False,
                 # Deprecated parameter to load older models
                 entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
                 entity_similarity_mode: str = "dot_product",
                 rule_namespace: str = 'rule_labels') -> None:
        # Sets up the question embedder and encoder, the metrics, the entity-type
        # embeddings, the optional linking and denotation layers, and the action
        # embeddings used by the decoder.
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
            # NOTE(review): an ``else:`` appears to be missing; as written the dropout
            # module is immediately replaced by the identity function.
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 1 # Hardcoded until we feed lf syntax into the model
        self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = \
                TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
            # NOTE(review): the statement on the next line has lost its left-hand side
            # (only ``+= 1`` remains) and is not valid as written.
   += 1  # pylint: disable=protected-access
        elif self._entity_similarity_mode == "dot_product":
            # NOTE(review): as written, the supported mode "dot_product" raises; a ``pass``
            # followed by an ``else:`` appears to be missing before the ``raise``.
            raise ValueError("Invalid entity_similarity_mode: {}".format(self._entity_similarity_mode))

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
            # NOTE(review): an ``else:`` appears to be missing before the next line.
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        # The entity bits widen the encoder output when ``entity_bits_output`` is set.
        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            # NOTE(review): the ``Linear(...)`` call below is cut off after its first
            # argument, and whatever followed it inside this branch is not present.
            self._denotation_classifier = torch.nn.Linear(self._encoder_output_dim,
            # Rest of init not needed for denotation only where no decoding to actions needed

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        # We are tying the action embeddings used for input and output
        # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(self._encoder_output_dim))

        # NOTE(review): the constructor call below is cut off after its first keyword
        # argument; the rest of this method is not present.
        self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder_output_dim,
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's Bidirectional Attention Flow model
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """
        Store the sub-modules, build the span start/end predictors, check that the configured
        layer dimensions chain together, and finally apply ``initializer`` to the model.
        """
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        # The highway stack works on the embedder output, applied independently per token.
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        # The span start predictor sees the merged passage (4 * encoding_dim) concatenated with
        # the modeled passage.
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        # The span end predictor sees the merged passage concatenated with the span end encoding.
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so ``forward`` can call ``self._dropout`` unconditionally.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        # Applied last so that every sub-module registered above is covered.
        initializer(self)


    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.  The optional keys
            ``question_tokens``, ``passage_tokens`` and ``answer_texts`` are used when present.

        Returns
        -------
        An output dictionary consisting of:
        passage_question_attention : torch.FloatTensor
            The passage-to-question attention, shape
            ``(batch_size, passage_length, question_length)``.
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        question_tokens, passage_tokens : List
            Only when ``metadata`` is given: the tokenized inputs echoed back from the metadata.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        # With ``mask_lstms=False`` the encoders run unmasked; softmaxes below still use the masks.
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Push padded positions far down so the span search never selects them.
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                instance_metadata = metadata[i]
                # NOTE(review): token keys are assumed to mirror the output keys; ``.get`` keeps
                # callers that only supply the documented keys working.
                question_tokens.append(instance_metadata.get('question_tokens', []))
                passage_tokens.append(instance_metadata.get('passage_tokens', []))
                passage_str = instance_metadata['original_passage']
                offsets = instance_metadata['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Character span runs from the start of the first token to the end of the last.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = instance_metadata.get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the current values of all tracked metrics.

        Parameters
        ----------
        reset : ``bool``, optional (default=False)
            Forwarded unchanged to every underlying metric's ``get_metric``.

        Returns
        -------
        A dictionary with the keys ``start_acc``, ``end_acc``, ``span_acc``, ``em`` and ``f1``.
        """
        # The SQuAD metric reports exact match and F1 together as a pair.
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Find, for every batch element, the span ``(start, end)`` with ``start <= end`` that
        maximises ``span_start_logits[start] + span_end_logits[end]``.

        This takes no ``self`` and is invoked as ``self.get_best_span(...)``, so it is declared
        as a static method.

        Parameters
        ----------
        span_start_logits : ``torch.Tensor``
            Shape ``(batch_size, passage_length)``.
        span_end_logits : ``torch.Tensor``
            Shape ``(batch_size, passage_length)``.

        Returns
        -------
        A ``torch.long`` tensor of shape ``(batch_size, 2)`` holding the inclusive start and end
        index of the best span, created with ``new_zeros`` from ``span_start_logits``.

        Raises
        ------
        ValueError
            If either input is not two-dimensional.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        # The scan below is element-wise, so work on numpy copies of the logits.
        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Keep the best start seen at or before ``j``; this enforces start <= end.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                # Strict comparison: on ties the earliest span found is kept.
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30,
                 max_turn_length: int = 12) -> None:
        """
        Store the encoders, build the attention / projection / predictor layers and metrics,
        check the phrase layer input dimension, and apply ``initializer`` to the model.
        """
        # NOTE(review): the enclosing class header is not visible here; this assumes a
        # ``Model`` subclass whose base constructor takes ``vocab`` - confirm.
        # The base constructor must run before any sub-module is assigned.
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        # Marker embeddings are only needed when previous answers are used as context.
        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(max_turn_length,
                                                           marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        # NOTE(review): the output size was lost in the source; ``encoding_dim`` is assumed,
        # mirroring ``_merge_atten`` above - confirm.
        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,
                                                                     self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        # The phrase layer consumes the token embedding plus one marker embedding per context
        # answer.
        check_dimensions_match(phrase_layer.get_input_dim(),
                               text_field_embedder.get_output_dim() +
                               marker_embedding_dim * num_context_answers,
                               "phrase layer input dim",
                               "embedding dim + marker dim * num context answers")

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)
class ESIM(Model):
    """
    This ``Model`` implements the ESIM sequence model described in "Enhanced LSTM for Natural
    Language Inference" by Chen et al., 2017.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
        model.
    encoder : ``Seq2SeqEncoder``
        Used to encode the premise and hypothesis.
    similarity_function : ``SimilarityFunction``
        This is the similarity function used when computing the similarity matrix between encoded
        words in the premise and words in the hypothesis.
    projection_feedforward : ``FeedForward``
        The feedforward network used to project down the encoded and enhanced premise and hypothesis.
    inference_encoder : ``Seq2SeqEncoder``
        Used to encode the projected premise and hypothesis for prediction.
    output_feedforward : ``FeedForward``
        Used to prepare the concatenated premise and hypothesis for prediction.
    output_logit : ``FeedForward``
        This feedforward network computes the output logits.
    dropout : ``float``, optional (default=0.5)
        Dropout percentage to use.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 output_logit: FeedForward,
                 dropout: float = 0.5,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """
        Store the sub-modules, set up dropout, check that the layer dimensions chain together,
        create the accuracy metric and loss, and apply ``initializer`` to the model.
        """
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder

        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._projection_feedforward = projection_feedforward

        self._inference_encoder = inference_encoder

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
            self.rnn_input_dropout = InputVariationalDropout(dropout)
        else:
            # ``forward`` checks these for truthiness and skips dropout when they are ``None``.
            self.dropout = None
            self.rnn_input_dropout = None

        self._output_feedforward = output_feedforward
        self._output_logit = output_logit

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        # The enhancement layer concatenates four encoder-sized tensors before projection.
        check_dimensions_match(encoder.get_output_dim() * 4, projection_feedforward.get_input_dim(),
                               "encoder output dim", "projection feedforward input")
        check_dimensions_match(projection_feedforward.get_output_dim(), inference_encoder.get_input_dim(),
                               "proj feedforward output dim", "inference lstm input dim")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        # Applied last so that every sub-module registered above is covered.
        initializer(self)


    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
            It is not used by this method.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise],
                dim=-1
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        # Masked positions are pushed very negative so they cannot win the max over time.
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        # Average over unmasked positions only: masked sum divided by the mask total.
        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True
        )

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Report the accuracy metric under the key ``accuracy``.

        ``reset`` is forwarded unchanged to the metric's ``get_metric``.
        """
        current_accuracy = self._accuracy.get_metric(reset)
        metrics = {}
        metrics['accuracy'] = current_accuracy
        return metrics
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 mask_lstms: bool = True) -> None:
        """
        Store the sub-modules, build the span start/end predictors, validate that the configured
        layer dimensions chain together, and apply ``initializer`` to the model.

        Raises
        ------
        ConfigurationError
            If the modeling layer, phrase layer or span end encoder input dimension does not
            match what the preceding layers produce.
        """
        super(BidirectionalAttentionFlow, self).__init__(vocab)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these
        # aren't necessarily obvious from the configuration files, so we check
        # here.
        if modeling_layer.get_input_dim() != 4 * encoding_dim:
            raise ConfigurationError(
                "The input dimension to the modeling_layer must be "
                "equal to 4 times the encoding dimension of the phrase_layer. "
                "Found {} and 4 * {} respectively.".format(
                    modeling_layer.get_input_dim(), encoding_dim))
        if text_field_embedder.get_output_dim() != phrase_layer.get_input_dim():
            raise ConfigurationError(
                "The output dimension of the text_field_embedder (embedding_dim + "
                "char_cnn) must match the input dimension of the phrase_encoder. "
                "Found {} and {}, respectively.".format(
                    text_field_embedder.get_output_dim(),
                    phrase_layer.get_input_dim()))

        if span_end_encoder.get_input_dim(
        ) != encoding_dim * 4 + modeling_dim * 3:
            raise ConfigurationError(
                "The input dimension of the span_end_encoder should be equal to "
                "4 * phrase_layer.output_dim + 3 * modeling_layer.output_dim. "
                "Found {} and (4 * {} + 3 * {}) "
                "respectively.".format(span_end_encoder.get_input_dim(),
                                       encoding_dim, modeling_dim))

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._official_em = Average()
        self._official_f1 = Average()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so callers can apply ``self._dropout`` unconditionally.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        # Applied last so that every sub-module registered above is covered.
        initializer(self)
class BidirectionalAttentionFlow_1(Model):
    """
    This class implements a Bayesian version of Minjoon Seo's Bidirectional Attention Flow model
    for answering reading comprehension questions (ICLR 2017).
    """
    def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo = None) -> None:
        """
        Build every layer of the model from the configuration object ``cf_a``.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Forwarded to the ``Model`` base-class constructor.
        cf_a
            Configuration object.  It is stored as ``self.cf_a`` and supplies the regularizer,
            the ELMo settings, the sizes/dropouts of every recurrent layer, the ``VB_*`` flags
            that switch individual layers to their variational (Bayesian) version together with
            the matching ``VB_*_prior`` keyword dictionaries, and ``mask_lstms``.
        preloaded_elmo : optional (default=None)
            An already constructed ELMo embedder.  When it is given and ``cf_a.use_ELMO`` is
            set, it is used instead of calling ``bidut.download_Elmo``.
        """
        super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer)

        # Initialize some data structures
        self.cf_a = cf_a
        # Bayesian data models
        # NOTE(review): nothing visible in this constructor appends to these two lists.
        # Confirm whether the variational layers built below are meant to be registered
        # here for ``sample_posterior`` / ``get_KL_divergence``.
        self.VBmodels = []
        self.LinearModels = []

        # Output size of the optional linear projection applied on top of ELMo; it is also
        # the highway input size whenever the projection is enabled.
        elmo_projection_dim = 200

        ############## TEXT FIELD EMBEDDER with ELMO ####################
        # text_field_embedder : ``TextFieldEmbedder``
        #     Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input
        #     to the model.
        if cf_a.use_ELMO:
            if preloaded_elmo is not None:
                text_field_embedder = preloaded_elmo
            else:
                text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput)
                print ("ELMO loaded from disk or downloaded")
        else:
            # NOTE(review): with no embedder the ``get_output_dim()`` calls below fail;
            # presumably ``use_ELMO`` is always set - confirm.
            text_field_embedder = None
        self._text_field_embedder = text_field_embedder

        if cf_a.Add_Linear_projection_ELMO:
            if self.cf_a.VB_Linear_projection_ELMO:
                prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior))
                print ("----------------- Bayesian Linear Projection ELMO --------------")
                linear_projection_ELMO = LinearVB(text_field_embedder.get_output_dim(),
                                                  elmo_projection_dim, prior = prior)
            else:
                linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(),
                                                         elmo_projection_dim)
            self._linear_projection_ELMO = linear_projection_ELMO

        ############## Highway layers ####################
        # num_highway_layers : ``int``
        #     The number of highway layers to use in between embedding the input and passing
        #     it through the phrase layer.
        if cf_a.Add_Linear_projection_ELMO:
            Input_dimension_highway = elmo_projection_dim
        else:
            Input_dimension_highway = text_field_embedder.get_output_dim()
        num_highway_layers = cf_a.num_highway_layers

        if self.cf_a.VB_highway_layers:
            print ("----------------- Bayesian Highway network  --------------")
            prior = Vil.Prior(**(cf_a.VB_highway_layers_prior))
            highway_layer = HighwayVB(Input_dimension_highway, num_highway_layers, prior = prior)
        else:
            highway_layer = Highway(Input_dimension_highway, num_highway_layers)
        highway_layer = TimeDistributed(highway_layer)
        self._highway_layer = highway_layer

        ############## Phrase layer ####################
        # phrase_layer : ``Seq2SeqEncoder``
        #     The encoder (with its own internal stacking) that we will use in between
        #     embedding tokens and doing the bidirectional attention.
        if cf_a.phrase_layer_dropout > 0:
            dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout)
        else:
            # Identity stand-in so ``forward`` can call the dropout unconditionally.
            dropout_phrase_layer = lambda x: x
        phrase_layer = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(Input_dimension_highway,
                          hidden_size = cf_a.phrase_layer_hidden_size,
                          batch_first = True, bidirectional = True,
                          num_layers = cf_a.phrase_layer_num_layers,
                          dropout = cf_a.phrase_layer_dropout))
        # Bidirectional LSTM: forward and backward states are concatenated.
        phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2
        self._phrase_layer = phrase_layer
        self._dropout_phrase_layer = dropout_phrase_layer

        ############## Matrix attention layer ####################
        # similarity_function : ``SimilarityFunction``
        #     The similarity function that we will use when comparing encoded passage and
        #     question.
        if self.cf_a.VB_similarity_function:
            prior = Vil.Prior(**(cf_a.VB_similarity_function_prior))
            print ("----------------- Bayesian Similarity matrix --------------")
            similarity_function = LinearSimilarityVB(
                combination = "x,y,x*y",
                tensor_1_dim = phrase_encoding_out_dim,
                tensor_2_dim = phrase_encoding_out_dim, prior = prior)
        else:
            similarity_function = LinearSimilarity(
                combination = "x,y,x*y",
                tensor_1_dim = phrase_encoding_out_dim,
                tensor_2_dim = phrase_encoding_out_dim)
        matrix_attention = LegacyMatrixAttention(similarity_function)
        self._matrix_attention = matrix_attention

        ############## Modelling Layer ####################
        # modeling_layer : ``Seq2SeqEncoder``
        #     The encoder (with its own internal stacking) that we will use in between the
        #     bidirectional attention and predicting span start and end.
        if cf_a.modeling_passage_dropout > 0:
            dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout)
        else:
            dropout_modeling_passage = lambda x: x
        modeling_layer = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(phrase_encoding_out_dim * 4,
                          hidden_size = cf_a.modeling_passage_hidden_size,
                          batch_first = True, bidirectional = True,
                          num_layers = cf_a.modeling_passage_num_layers,
                          dropout = cf_a.modeling_passage_dropout))
        self._modeling_layer = modeling_layer
        self._dropout_modeling_passage = dropout_modeling_passage

        ############## Span Start Representation #####################
        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        # Linear layer to compute the span start score of every passage token.
        if self.cf_a.VB_span_start_predictor_linear:
            prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior))
            print ("----------------- Bayesian Span Start Predictor--------------")
            span_start_predictor_linear = LinearVB(span_start_input_dim, 1, prior = prior)
        else:
            span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1)
        self._span_start_predictor_linear = span_start_predictor_linear
        self._span_start_predictor = TimeDistributed(span_start_predictor_linear)

        ############## Span End Representation #####################
        # span_end_encoder : ``Seq2SeqEncoder``
        #     The encoder that we will use to incorporate span start predictions into the
        #     passage state before predicting span end.
        if cf_a.span_end_encoder_dropout > 0:
            dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_span_end_encode = lambda x: x
        span_end_encoder = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(encoding_dim * 4 + modeling_dim * 3,
                          hidden_size = cf_a.modeling_span_end_hidden_size,
                          batch_first = True, bidirectional = True,
                          num_layers = cf_a.modeling_span_end_num_layers,
                          dropout = cf_a.span_end_encoder_dropout))
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_encoder = span_end_encoder
        self._dropout_span_end_encode = dropout_span_end_encode

        if self.cf_a.VB_span_end_predictor_linear:
            print ("----------------- Bayesian Span End Predictor--------------")
            prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior))
            span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior = prior)
        else:
            span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1)
        self._span_end_predictor_linear = span_end_predictor_linear
        self._span_end_predictor = TimeDistributed(span_end_predictor_linear)

        ############## Dropout of the last layers #####################
        if cf_a.spans_output_dropout > 0:
            # Use the rate that guards this branch; the span-end encoder has its own rate.
            dropout_spans_output = torch.nn.Dropout(p=cf_a.spans_output_dropout)
        else:
            dropout_spans_output = lambda x: x
        self._dropout_spans_output = dropout_spans_output

        ############## Checkings and accuracy #####################
        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(Input_dimension_highway, phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        # mask_lstms : ``bool``
        #     If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x
        #     speedup, with only a slight performance decrease, if any.  We still use the mask
        #     for all softmaxes, but avoid the shuffling that's required when using masking
        #     with pytorch LSTMs.
        self._mask_lstms = cf_a.mask_lstms

        ################### Initialize parameters ##############################
        ####################### OPTIMIZER ################
        optimizer = pytut.get_optimizers(self, cf_a)
        self._optimizer = optimizer
        #### TODO: Learning rate scheduler ####
        #scheduler = optim.ReduceLROnPlateau(optimizer, 'max')
    def forward_ensemble(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False) -> Dict[str, torch.Tensor]:
        """
        Sample 10 times and add them together

        Runs ``forward`` once, then 10 more times, and merges the span predictions of all
        runs with ``bidut.merge_span_probs``.  The arguments are the ones ``forward`` takes.

        Returns
        -------
        A dictionary with:
        best_span
            The merged prediction returned by ``bidut.merge_span_probs``.
        best_span_str : List[str]
            The passage substring of each merged span (filled only when ``metadata`` is given).
        models_output : list
            The raw output dictionary of every individual ``forward`` call.
        em_samples, f1_samples, span_start_sample_loss, span_end_sample_loss : list
            Per-instance values, present only when ``get_sample_level_information`` is set.
            The two loss lists are filled only when gold spans are given.
        """
        num_samples = 10

        # NOTE(review): the variable name suggests this first pass was meant to run with the
        # posterior mean (``set_posterior_mean``); no such call is visible here - confirm.
        most_likely_output = self.forward(question, passage, span_start, span_end,
                                          metadata, get_sample_level_information)
        subresults = [most_likely_output]
        for _ in range(num_samples):
            subresults.append(self.forward(question, passage, span_start, span_end,
                                           metadata, get_sample_level_information))

        batch_size = len(subresults[0]["best_span"])

        best_span = bidut.merge_span_probs(subresults)
        output = {
                "best_span": best_span,
                "best_span_str": [],
                "models_output": subresults
        }
        if get_sample_level_information:
            output["em_samples"] = []
            output["f1_samples"] = []

        if metadata is not None:
            for index in range(batch_size):
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                # Token offsets are (start_char, end_char) pairs into the original passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)

        if get_sample_level_information:
            # Add information about the individual samples for future analysis
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            if span_start is not None and span_end is not None:
                # The ensemble average does not depend on the instance, so compute it once.
                mean_start_probs = sum(sub['span_start_probs'] for sub in subresults) / len(subresults)
                mean_end_probs = sum(sub['span_end_probs'] for sub in subresults) / len(subresults)
                gold_starts = span_start.squeeze(-1)
                gold_ends = span_end.squeeze(-1)
                for i in range(batch_size):
                    # nll_loss expects log-probabilities, so take the log of the averaged
                    # probabilities before scoring the gold index.
                    span_start_loss = nll_loss(torch.log(mean_start_probs[[i], :]), gold_starts[[i]])
                    span_end_loss = nll_loss(torch.log(mean_end_probs[[i], :]), gold_ends[[i]])
                    # NOTE(review): stored as numpy scalars, mirroring how the validation
                    # loss is accumulated elsewhere in this class - confirm the format.
                    output["span_start_sample_loss"].append(span_start_loss.detach().cpu().numpy())
                    output["span_end_sample_loss"].append(span_end_loss.detach().cpu().numpy())
        return output
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False,
                get_attentions = False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        get_sample_level_information : ``bool``, optional (default=False)
            If set, per-instance EM/F1 values and per-instance span losses are added to the output.
        get_attentions : ``bool``, optional (default=False)
            If set, the two attention maps and the similarity matrix are added to the output.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        #################### Sample Bayesian weights ##################
        # NOTE(review): no sampling call is visible under this heading; confirm whether
        # ``self.sample_posterior()`` is expected to run here or inside the layers.

        ################## MASK COMPUTING ########################
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        ###################### EMBEDDING + HIGHWAY LAYER ########################
        def _embed(text_field):
            # Last ELMo representation, optionally projected, then the highway network.
            elmo_output = self._text_field_embedder(text_field['character_ids'])["elmo_representations"][-1]
            if self.cf_a.Add_Linear_projection_ELMO:
                elmo_output = self._linear_projection_ELMO(elmo_output)
            return self._highway_layer(elmo_output)

        embedded_question = _embed(question)
        embedded_passage = _embed(passage)
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)

        ###################### phrase_layer LAYER ########################
        encoded_question = self._dropout_phrase_layer(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout_phrase_layer(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        ###################### Attention LAYER ########################
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout_modeling_passage(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        ###################### Spans LAYER ########################
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout_spans_output(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout_span_end_encode(self._span_end_encoder(span_end_representation,
                                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout_spans_output(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = bidut.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            loss = span_start_loss + span_end_loss

            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
            output_dict["span_start_loss"] = span_start_loss
            output_dict["span_end_loss"] = span_end_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            if get_sample_level_information:
                output_dict["em_samples"] = []
                output_dict["f1_samples"] = []
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                # NOTE(review): assumes every metadata entry carries 'question_tokens' and
                # 'passage_tokens', as in the reference BiDAF reader - confirm.
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Token offsets are (start_char, end_char) pairs into the original passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string, answer_texts)
                        output_dict["em_samples"].append(em_sample)
                        output_dict["f1_samples"].append(f1_sample)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens

        if get_sample_level_information:
            # Add information about the individual samples for future analysis
            output_dict["span_start_sample_loss"] = []
            output_dict["span_end_sample_loss"] = []
            if span_start is not None and span_end is not None:
                for i in range(batch_size):
                    # Index with a one-element list to keep the batch dimension.
                    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits[[i],:], passage_mask[[i],:]), span_start.squeeze(-1)[[i]])
                    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits[[i],:], passage_mask[[i],:]), span_end.squeeze(-1)[[i]])
                    # NOTE(review): stored as numpy scalars - confirm the expected format.
                    output_dict["span_start_sample_loss"].append(span_start_loss.detach().cpu().numpy())
                    output_dict["span_end_sample_loss"].append(span_end_loss.detach().cpu().numpy())

        if get_attentions:
            output_dict["C2Q_attention"] = passage_question_attention
            output_dict["Q2C_attention"] = question_passage_attention
            output_dict["simmilarity"] = passage_question_similarity
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the current value of every metric tracked by the model.

        Parameters
        ----------
        reset : ``bool``, optional (default=False)
            Passed through unchanged to each metric's ``get_metric``.

        Returns
        -------
        A dictionary with the keys ``start_acc``, ``end_acc``, ``span_acc``, ``em`` and ``f1``.
        """
        # The SQuAD metric yields both official scores in a single call.
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
        }
    def train_batch(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Run one optimisation step on a batch and return the ``forward`` output.

        It is enough to just compute the total loss because the normal weights
        do not depend on the KL Divergence.

        The arguments are the ones ``forward`` takes; ``span_start`` and ``span_end`` must be
        given, since the step needs ``output["loss"]``.
        """
        # Now we can just compute both losses which will build the dynamic graph
        output = self.forward(question, passage, span_start, span_end, metadata)
        data_loss = output["loss"]
        KL_div = self.get_KL_divergence()
        total_loss = self.combine_losses(data_loss, KL_div)

        self.zero_grad()     # zeroes the gradient buffers of all parameters
        total_loss.backward()

        if self._optimizer is None:
            # No optimizer configured: fall back to a plain gradient-descent update.
            # NOTE(review): the original update expression is truncated in this file; the
            # step size attribute (``self.lr``) is an assumption - confirm before relying
            # on this branch.
            trainable = (p for p in self.parameters() if p.requires_grad)
            with torch.no_grad():
                for param in trainable:
                    if param.grad is not None:
                        param.sub_(param.grad * self.lr)
        else:
            self._optimizer.step()
        return output
    def fill_batch_training_information(self, training_logger,
                                        output_batch):
        """
        Function to fill the the training_logger for each batch.

        training_logger: Dictionary that will hold all the training info
        output_batch: Output from training the batch
        """
        # Training metrics:
        metrics = self.get_metrics()
        # NOTE(review): the statements that copy ``metrics`` / ``output_batch`` into
        # ``training_logger`` are not present in this file, so nothing is recorded here
        # yet - restore them from the original source.
    def fill_epoch_training_information(self, training_logger,device,
                                        validation_iterable, num_batches_validation):
        """
        Fill the information per each epoch

        Reads (and resets) the training metrics, then runs ``num_batches_validation`` batches
        taken from ``validation_iterable`` through ``forward`` without gradients, moving each
        batch to ``device`` first, and finally reads (and resets) the validation metrics.

        A batch that raises ``RuntimeError`` is retried up to 100 times; after that the error
        is re-raised.
        """
        import gc

        Ntrials_CUDA = 100
        # Training Epoch final metrics
        metrics = self.get_metrics(reset = True)
        data_loss_validation = 0
        with torch.no_grad():
            # Compute the validation accuracy by using all the Validation dataset but in batches.
            for j in range(num_batches_validation):
                tensor_dict = next(validation_iterable)
                trial_index = 0
                while True:
                    try:
                        tensor_dict = pytut.move_to_device(tensor_dict, device) ## Move the tensor to cuda
                        output_batch = self.forward(**tensor_dict)
                        break
                    except RuntimeError as er:
                        print (er.args)
                        trial_index += 1
                        if trial_index == Ntrials_CUDA:
                            print ("Too many failed trials to allocate in memory")
                            # Give up instead of retrying forever.
                            raise
                data_loss_validation += output_batch["loss"].detach().cpu().numpy()

            ## Memory management !!
            # The names below only exist once at least one batch has been processed.
            if num_batches_validation > 0 and self.cf_a.force_free_batch_memory:
                del tensor_dict["question"]; del tensor_dict["passage"]
                del tensor_dict
                del output_batch
            if self.cf_a.force_call_garbage_collector:
                gc.collect()
            if num_batches_validation > 0:
                data_loss_validation = data_loss_validation/num_batches_validation

        # Validation Epoch final metrics
        metrics = self.get_metrics(reset = True)
        # NOTE(review): the statements that store the metrics and ``data_loss_validation``
        # in ``training_logger`` are not present in this file - restore them from the
        # original source.
    def trim_model(self, mu_sigma_ratio = 2):
        """
        Trim the weights of every variational layer that is enabled in the configuration.

        Each enabled layer is passed to ``Vil.trim_LinearVB_weights`` together with
        ``mu_sigma_ratio``, and the four values it returns are collected per layer.

        Parameters
        ----------
        mu_sigma_ratio : optional (default=2)
            Threshold forwarded unchanged to ``Vil.trim_LinearVB_weights``.

        Returns
        -------
        ``(total_size_w, total_removed_w, total_size_b, total_removed_b)``: four lists with
        one entry per trimmed layer, in the order ELMo projection, highway, similarity
        function, span start predictor, span end predictor.
        """
        total_size_w = []
        total_removed_w = []
        total_size_b = []
        total_removed_b = []

        def _trim_and_record(VBmodel):
            # Trim once per layer (a second pass would only re-trim the same layer) and
            # keep the reported statistics so the returned lists are actually filled.
            size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio)
            total_size_w.append(size_w)
            total_removed_w.append(removed_w)
            total_size_b.append(size_b)
            total_removed_b.append(removed_b)

        if self.cf_a.VB_Linear_projection_ELMO:
            _trim_and_record(self._linear_projection_ELMO)
        if self.cf_a.VB_highway_layers:
            # The highway network is wrapped in TimeDistributed; reach its first VB sub-model.
            _trim_and_record(self._highway_layer._module.VBmodels[0])
        if self.cf_a.VB_similarity_function:
            _trim_and_record(self._matrix_attention._similarity_function)
        if self.cf_a.VB_span_start_predictor_linear:
            _trim_and_record(self._span_start_predictor_linear)
        if self.cf_a.VB_span_end_predictor_linear:
            _trim_and_record(self._span_end_predictor_linear)

        return  total_size_w, total_removed_w, total_size_b, total_removed_b
#    print (weights_to_remove_W.shape)

    # Reuse the variational-inference helpers defined on GeneralVBModel as methods of this
    # class; the class itself derives from ``Model``, so they are bound here by assignment.
    sample_posterior = GeneralVBModel.sample_posterior
    get_KL_divergence = GeneralVBModel.get_KL_divergence
    set_posterior_mean = GeneralVBModel.set_posterior_mean
    combine_losses = GeneralVBModel.combine_losses
    def save_VB_weights(self):
        """
        Function that saves only the VB weights of the model.

        Not implemented yet: the body was only an outline whose source dictionary was a
        literal ``...`` placeholder, so it could never run.

        Raises
        ------
        NotImplementedError
            Always.
        """
        # Outline kept from the placeholder body:
        #   1. filter out unnecessary keys (keep only those present in ``self.state_dict()``)
        #   2. overwrite entries in the existing state dict
        #   3. load the new state dict
        raise NotImplementedError("save_VB_weights is not implemented yet")
# Example #44
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's Bidirectional Attention Flow model
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    To instantiate this model with parameters matching those in the original paper, simply use
    ``BidirectionalAttentionFlow.from_params(vocab, Params({}))``.  This will construct all of the
    various dependencies needed for the constructor for you.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    attention_similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    initializer : ``InitializerApplicator``
        We will use this to initialize the parameters in the model, calling ``initializer(self)``.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    evaluation_json_file : ``str``, optional
        Deprecated.  It is not a constructor argument; ``from_params`` still pops this key from
        the configuration and logs a warning when it is present.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 mask_lstms: bool = True) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))
        # Apply the configured initialisation once every sub-module exists, as
        # promised by the ``initializer`` entry of the class docstring.
        initializer(self)

        # Bidaf has lots of layer dimensions which need to match up - these
        # aren't necessarily obvious from the configuration files, so we check
        # here.
        if modeling_layer.get_input_dim() != 4 * encoding_dim:
            raise ConfigurationError(
                "The input dimension to the modeling_layer must be "
                "equal to 4 times the encoding dimension of the phrase_layer. "
                "Found {} and 4 * {} respectively.".format(
                    modeling_layer.get_input_dim(), encoding_dim))
        if text_field_embedder.get_output_dim() != phrase_layer.get_input_dim():
            raise ConfigurationError(
                "The output dimension of the text_field_embedder (embedding_dim + "
                "char_cnn) must match the input dimension of the phrase_encoder. "
                "Found {} and {}, respectively.".format(
                    text_field_embedder.get_output_dim(),
                    phrase_layer.get_input_dim()))

        if span_end_encoder.get_input_dim(
        ) != encoding_dim * 4 + modeling_dim * 3:
            raise ConfigurationError(
                "The input dimension of the span_end_encoder should be equal to "
                "4 * phrase_layer.output_dim + 3 * modeling_layer.output_dim. "
                "Found {} and (4 * {} + 3 * {}) "
                "respectively.".format(span_end_encoder.get_input_dim(),
                                       encoding_dim, modeling_dim))

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._official_em = Average()
        self._official_f1 = Average()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so call sites never need to branch on dropout.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.  The ending position is `exclusive`, so our dataset reader adds a special
            ending token to the end of the passage, to allow for the last token to be included
            in the answer span.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `exclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span end position (exclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        # NOTE(review): the trailing squeeze only matters on PyTorch versions
        # whose ``max(dim=...)`` keeps the reduced dimension - confirm.
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ], dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ], dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Push padded positions to a very negative value so the span search
        # below can never select them.
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self._get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span
        }
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        if metadata is not None:
            output_dict['best_span_str'] = []
            for i in range(batch_size):
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                exact_match = f1_score = 0
                if answer_texts:
                    exact_match = squad_eval.metric_max_over_ground_truths(
                        squad_eval.exact_match_score, best_span_string,
                        answer_texts)
                    f1_score = squad_eval.metric_max_over_ground_truths(
                        squad_eval.f1_score, best_span_string, answer_texts)
                self._official_em(100 * exact_match)
                self._official_f1(100 * f1_score)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the current value of every tracked metric, keyed by name.

        ``reset`` is forwarded unchanged to each metric's ``get_metric``.
        """
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': self._official_em.get_metric(reset),
            'f1': self._official_f1.get_metric(reset),
        }

    @staticmethod
    def _get_best_span(span_start_logits: Variable,
                       span_end_logits: Variable) -> Variable:
        """Pick, per batch row, the ``(start, end)`` pair with the highest summed logit.

        A single left-to-right pass keeps the best start seen so far, so the
        chosen start index is never greater than the chosen end index.

        Raises
        ------
        ValueError
            If either input is not two-dimensional.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        # Allocate the result from the input's own tensor type, then cast to long.
        best_word_span = Variable(
            span_start_logits.data.new().resize_(
                batch_size, 2).fill_(0)).long()

        # The scan below is plain Python, so work on host-side numpy copies.
        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary,
                    params: Params) -> 'BidirectionalAttentionFlow':
        """Build the model and all of its sub-modules from a ``Params`` configuration.

        NOTE(review): the keys ``"similarity_function"``, ``"modeling_layer"`` and
        ``"span_end_encoder"`` were lost from the extracted text and are restored
        from the upstream AllenNLP model - confirm against the configuration files.
        """
        import logging

        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(
            vocab, embedder_params)
        num_highway_layers = params.pop("num_highway_layers")
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        similarity_function = SimilarityFunction.from_params(
            params.pop("similarity_function"))
        modeling_layer = Seq2SeqEncoder.from_params(
            params.pop("modeling_layer"))
        span_end_encoder = Seq2SeqEncoder.from_params(
            params.pop("span_end_encoder"))
        initializer = InitializerApplicator.from_params(
            params.pop("initializer", []))
        dropout = params.pop('dropout', 0.2)

        # TODO: Remove the following when fully deprecated
        evaluation_json_file = params.pop('evaluation_json_file', None)
        if evaluation_json_file is not None:
            logging.getLogger(__name__).warning(
                "the 'evaluation_json_file' model parameter is deprecated, please remove"
            )

        mask_lstms = params.pop('mask_lstms', True)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   num_highway_layers=num_highway_layers,
                   phrase_layer=phrase_layer,
                   attention_similarity_function=similarity_function,
                   modeling_layer=modeling_layer,
                   span_end_encoder=span_end_encoder,
                   initializer=initializer,
                   dropout=dropout,
                   mask_lstms=mask_lstms)
# Example #45
class DecomposableAttention(Model):
    """
    A fixed-size decomposable attention classifier over two sentences.

    Both sentences are embedded with the same token embedder, projected, aligned
    against each other through a dot-product attention matrix, compared, summed
    over the sequence dimension and finally mapped to ``num_labels`` logits.

    Parameters
    ----------
    vocab : ``Vocabulary``
    token_embedder : ``TokenEmbedder``
        Wrapped in a ``BasicTextFieldEmbedder`` under the ``"tokens"`` key; its output
        dimension sizes the attend and compare networks.
    num_labels : ``int``
        Number of output classes.
    """

    def __init__(self, vocab: Vocabulary, token_embedder: TokenEmbedder,
                 num_labels: int) -> None:
        # Sub-modules can only be assigned after the base module is initialised.
        super().__init__(vocab)

        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": token_embedder})
        dim = token_embedder.get_output_dim()
        self._attend_feedforward = TimeDistributed(
            FeedForward(dim, 1, 100, torch.nn.ReLU(), 0.2))
        self._matrix_attention = DotProductMatrixAttention()
        self._compare_feedforward = TimeDistributed(
            FeedForward(dim * 2, 1, 100, torch.nn.ReLU(), 0.2))

        # linear denotes "lambda x: x"

        # NOTE(review): 200 presumably equals 2 * 100, i.e. the two summed
        # compare outputs concatenated - confirm against FeedForward's signature.
        self._aggregate_feedforward = FeedForward(200, 1, num_labels,
                                                  PassThrough(), 0.0)

        self._num_labels = num_labels

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

    def forward(  # type: ignore
        self,
        sent1: TextFieldTensors,
        sent2: TextFieldTensors,
        label: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Classify a pair of sentences.

        Parameters
        ----------
        sent1 : TextFieldTensors
        sent2 : TextFieldTensors
        label : torch.IntTensor, optional (default = None)
            When given, the cross-entropy loss is added to the output and the
            accuracy metric is updated.

        Returns
        -------
        An output dictionary with ``"logits"``, ``"probs"`` and, when ``label`` is
        given, ``"loss"``.
        """
        with adv_utils.forward_context("sent1"):
            embedded_sent1 = self._text_field_embedder(sent1)
        with adv_utils.forward_context("sent2"):
            embedded_sent2 = self._text_field_embedder(sent2)
        sent1_mask = get_text_field_mask(sent1)
        sent2_mask = get_text_field_mask(sent2)

        projected_sent1 = self._attend_feedforward(embedded_sent1)
        projected_sent2 = self._attend_feedforward(embedded_sent2)
        # Shape: (batch_size, sent1_length, sent2_length)
        similarity_matrix = self._matrix_attention(projected_sent1,
                                                   projected_sent2)

        # Shape: (batch_size, sent1_length, sent2_length)
        p2h_attention = masked_softmax(similarity_matrix, sent2_mask)
        # Shape: (batch_size, sent1_length, embedding_dim)
        attended_sent2 = weighted_sum(embedded_sent2, p2h_attention)

        # Shape: (batch_size, sent2_length, sent1_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), sent1_mask)
        # Shape: (batch_size, sent2_length, embedding_dim)
        attended_sent1 = weighted_sum(embedded_sent1, h2p_attention)

        sent1_compare_input = torch.cat([embedded_sent1, attended_sent2],
                                        dim=-1)
        sent2_compare_input = torch.cat([embedded_sent2, attended_sent1],
                                        dim=-1)

        compared_sent1 = self._compare_feedforward(sent1_compare_input)
        # Zero out padded positions before summing over the sequence.
        compared_sent1 = compared_sent1 * sent1_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_sent1 = compared_sent1.sum(dim=1)

        compared_sent2 = self._compare_feedforward(sent2_compare_input)
        compared_sent2 = compared_sent2 * sent2_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_sent2 = compared_sent2.sum(dim=1)

        aggregate_input = torch.cat([compared_sent1, compared_sent2], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "logits": label_logits,
            "probs": label_probs,
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running classification accuracy under the key ``"accuracy"``."""
        return {"accuracy": self._accuracy.get_metric(reset)}

    def get_optimizer(self):
        """Return a ``DenseSparseAdam`` over this model's named parameters (lr=5e-4)."""
        return DenseSparseAdam(self.named_parameters(), lr=5e-4)
        # return AdadeltaOptimizer(self.named_parameters())
class DecomposableAttention(Model):
    """
    This ``Model`` implements the Decomposable Attention model described in "A Decomposable
    Attention Model for Natural Language Inference"
    by Parikh et al., 2016, with some optional enhancements before the decomposable attention
    actually happens.  Parikh's original model allowed for computing an "intra-sentence" attention
    before doing the decomposable entailment step.  We generalize this to any
    :class:`Seq2SeqEncoder` that can be applied to the premise and/or the hypothesis before
    computing entailment.

    The basic outline of this model is to get an embedded representation of each word in the
    premise and hypothesis, align words between the two, compare the aligned phrases, and make a
    final entailment decision based on this aggregated comparison.  Each step in this process uses
    a feedforward network to modify the representation.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
        model.
    attend_feedforward : ``FeedForward``
        This feedforward network is applied to the encoded sentence representations before the
        similarity matrix is computed between words in the premise and words in the hypothesis.
    similarity_function : ``SimilarityFunction``
        This is the similarity function used when computing the similarity matrix between words in
        the premise and words in the hypothesis.
    compare_feedforward : ``FeedForward``
        This feedforward network is applied to the aligned premise and hypothesis representations,
        individually.
    aggregate_feedforward : ``FeedForward``
        This final feedforward network is applied to the concatenated, summed result of the
        ``compare_feedforward`` network, and its output is used as the entailment class logits.
    premise_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        After embedding the premise, we can optionally apply an encoder.  If this is ``None``, we
        will do nothing.
    hypothesis_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        After embedding the hypothesis, we can optionally apply an encoder.  If this is ``None``,
        we will use the ``premise_encoder`` for the encoding (doing nothing if ``premise_encoder``
        is also ``None``).
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 attend_feedforward: FeedForward,
                 similarity_function: SimilarityFunction,
                 compare_feedforward: FeedForward,
                 aggregate_feedforward: FeedForward,
                 premise_encoder: Optional[Seq2SeqEncoder] = None,
                 hypothesis_encoder: Optional[Seq2SeqEncoder] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(DecomposableAttention, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._attend_feedforward = TimeDistributed(attend_feedforward)
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._compare_feedforward = TimeDistributed(compare_feedforward)
        self._aggregate_feedforward = aggregate_feedforward
        self._premise_encoder = premise_encoder
        # Fall back to the premise encoder when no hypothesis encoder is given.
        self._hypothesis_encoder = hypothesis_encoder or premise_encoder

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        check_dimensions_match(text_field_embedder.get_output_dim(), attend_feedforward.get_input_dim(),
                               "text field embedding dim", "attend feedforward input dim")
        check_dimensions_match(aggregate_feedforward.get_output_dim(), self._num_labels,
                               "final output dimension", "number of labels")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        # Apply the configured initialisation once every sub-module exists;
        # without this call the ``initializer`` argument would be ignored.
        initializer(self)

    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        # Zero out padded positions before summing over the sequence.
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits,
                       "label_probs": label_probs,
                       "h2p_attention": h2p_attention,
                       "p2h_attention": p2h_attention}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["premise_tokens"] = [x["premise_tokens"] for x in metadata]
            output_dict["hypothesis_tokens"] = [x["hypothesis_tokens"] for x in metadata]

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running classification accuracy under the key ``'accuracy'``."""
        return {
                'accuracy': self._accuracy.get_metric(reset),
        }
# Example #47
class SWAGExampleModel(Model):
    """An example model for the SWAG dataset.

    This model predicts on the SWAG task by encoding the startphrase and
    the four endings, taking their dot products and then predicting the
    most similar one.

    Parameters
    ----------
    vocab : Vocabulary
        The vocabulary for the data.
    text_field_embedder : TextFieldEmbedder
        A module to embed the text for both the startphrase and the
        endings.
    startphrase_encoder : Seq2VecEncoder
        The encoder for the startphrase.
    ending_encoder : Seq2VecEncoder
        The encoder for the endings. It will be applied to each ending
        separately.
    similarity : SimilarityFunction
        The notion of similarity to use between the startphrase and the
        ending embeddings. It is stored on the instance; ``forward`` as
        written here scores the endings with a plain dot product and
        does not call it.
    initializer : InitializerApplicator
        An initializer defining how to initialize all variables.
    regularizer : RegularizerApplicator, optional (default=None)
        Regularization to apply for training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 startphrase_encoder: Seq2VecEncoder,
                 ending_encoder: Seq2VecEncoder,
                 similarity: SimilarityFunction,
                 initializer: InitializerApplicator,
                 regularizer: RegularizerApplicator = None) -> None:
        super().__init__(vocab, regularizer)

        # validate the configuration
        # NOTE(review): only the two message strings of each check survived
        # in the original text; the call heads and the dimension arguments
        # are reconstructed from those messages -- confirm against upstream.
        check_dimensions_match(
            text_field_embedder.get_output_dim(),
            startphrase_encoder.get_input_dim(),
            "text field embedding dim",
            "startphrase encoder input dim")
        check_dimensions_match(
            text_field_embedder.get_output_dim(),
            ending_encoder.get_input_dim(),
            "text field embedding dim",
            "ending encoder input dim")
        check_dimensions_match(
            startphrase_encoder.get_output_dim(),
            ending_encoder.get_output_dim(),
            "startphrase embedding dim",
            "ending embedding dim")

        # bind all attributes to the instance
        self.text_field_embedder = text_field_embedder
        self.startphrase_encoder = startphrase_encoder
        self.ending_encoder = ending_encoder
        self.similarity = similarity

        # set the training and validation losses
        self.xentropy = torch.nn.CrossEntropyLoss()
        self.accuracy = CategoricalAccuracy()

        # initialize all variables
        # NOTE(review): the statement under this comment was missing;
        # applying the ``initializer`` argument is the presumed intent.
        initializer(self)

    def forward(
        self,
        startphrase: Dict[str, torch.LongTensor],
        ending0: Dict[str, torch.LongTensor],
        ending1: Dict[str, torch.LongTensor],
        ending2: Dict[str, torch.LongTensor],
        ending3: Dict[str, torch.LongTensor],
        label: Optional[torch.IntTensor] = None,
        metadata: Optional[List[Dict[str, Any]]] = None
    ) -> Dict[str, torch.Tensor]:
        """Forward pass for predicting the best ending.

        Parameters
        ----------
        startphrase : Dict[str, torch.LongTensor]
            The startphrase field.
        ending0 : Dict[str, torch.LongTensor]
            The ending0 field.
        ending1 : Dict[str, torch.LongTensor]
            The ending1 field.
        ending2 : Dict[str, torch.LongTensor]
            The ending2 field.
        ending3 : Dict[str, torch.LongTensor]
            The ending3 field.
        label : Optional[torch.IntTensor]
            The index of the correct ending.
        metadata : Optional[List[Dict[str, Any]]]
            Optional additional metadata. Not read by this method.

        Returns
        -------
        A dictionary containing:

        logits : torch.FloatTensor
            A batch_size x num_endings tensor giving the logit for each
            of the endings.

        probabilities : torch.FloatTensor
            A batch_size x num_endings tensor giving the probabilities
            for each ending.

        loss : Optional[torch.FloatTensor]
            The training loss, or ``None`` when ``label`` is not given.
        """
        endings = (ending0, ending1, ending2, ending3)

        # pass the startphrase and endings through the initial text
        # embedding
        startphrase_initial = self.text_field_embedder(startphrase)
        endings_initial = [self.text_field_embedder(ending)
                           for ending in endings]

        # embed the startphrase and endings
        startphrase_embedding = self.startphrase_encoder(
            startphrase_initial, get_text_field_mask(startphrase))
        # NOTE(review): the second argument of every ending_encoder call
        # was missing; the mask is assumed by analogy with the startphrase
        # call directly above.
        ending_embeddings = [
            self.ending_encoder(initial, get_text_field_mask(ending))
            for initial, ending in zip(endings_initial, endings)
        ]

        # take the dot product of the embeddings

        # first, stack the endings so that we get a batch x num_endings
        # x embedding_dim tensor, then add an extra dimension to the
        # startphrase batch so it's a batch x embedding_dim x 1 tensor,
        # and broadcast matrix multiplication across the last two
        # dimensions to get the dot products
        # NOTE(review): only the first two lines of this expression
        # survived; the remainder follows the comment above. Squeezing
        # just the trailing axis keeps a batch of size one two-dimensional.
        logits = torch.matmul(
            torch.stack(ending_embeddings, dim=1),
            startphrase_embedding.unsqueeze(-1)
        ).squeeze(-1)

        # compute the probabilities
        probabilities = torch.nn.functional.softmax(logits, dim=-1)

        # compute the loss
        if label is not None:
            loss = self.xentropy(logits, label.long().view(-1))
            self.accuracy(logits, label)
        else:
            loss = None

        # return the output
        return {'logits': logits, 'probabilities': probabilities, 'loss': loss}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return ``{'accuracy': ...}`` from the accuracy metric.

        ``reset`` is forwarded unchanged to ``self.accuracy.get_metric``.
        """
        return {'accuracy': self.accuracy.get_metric(reset)}

    @classmethod
    def from_params(cls, vocab: Vocabulary,
                    params: Params) -> 'SWAGExampleModel':
        """Build the model from a ``Params`` configuration object.

        NOTE(review): the argument lines of the two ``Seq2VecEncoder``
        calls and everything after ``return cls(vocab=vocab,`` were
        missing. The keys ``'startphrase_encoder'`` and
        ``'ending_encoder'`` are assumed to mirror the constructor
        parameter names -- confirm against the configuration files.
        """
        text_field_embedder = TextFieldEmbedder.from_params(
            vocab, params.pop('text_field_embedder'))
        startphrase_encoder = Seq2VecEncoder.from_params(
            params.pop('startphrase_encoder'))
        ending_encoder = Seq2VecEncoder.from_params(
            params.pop('ending_encoder'))

        similarity = SimilarityFunction.from_params(params.pop('similarity'))

        initializer = InitializerApplicator.from_params(
            params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(
            params.pop('regularizer', []))

        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   startphrase_encoder=startphrase_encoder,
                   ending_encoder=ending_encoder,
                   similarity=similarity,
                   initializer=initializer,
                   regularizer=regularizer)
Exemple #48
    def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo = None) -> None:
        """Build every layer of the model from the configuration object ``cf_a``.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Passed to the parent constructor.
        cf_a
            Configuration object. Every hyper-parameter used below (ELMo
            options, layer sizes, dropout rates, the ``VB_*`` switches with
            their priors, ``regularizer``, ``mask_lstms``) is read from its
            attributes.
        preloaded_elmo : optional (default=None)
            An already constructed text field embedder. When it is given
            and ``cf_a.use_ELMO`` is true, it is used instead of calling
            ``bidut.download_Elmo``.

        NOTE(review): several ``else:`` lines and one ``if`` guard were
        missing from the original text and the descriptive fragments were
        bare (unquoted) text. The branches below are reconstructed from the
        surviving indentation -- confirm against the upstream file.
        """
        super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer)

        # Initialize some data structures
        self.cf_a = cf_a
        # Bayesian data models
        self.VBmodels = []
        self.LinearModels = []

        ############## TEXT FIELD EMBEDDER with ELMO ####################
        # text_field_embedder : ``TextFieldEmbedder``
        #     Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
        if cf_a.use_ELMO:
            if preloaded_elmo is not None:
                text_field_embedder = preloaded_elmo
            else:
                text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput)
                print("ELMO loaded from disk or downloaded")
        else:
            # NOTE(review): with no embedder, the ``get_output_dim()`` calls
            # below will fail -- confirm this configuration is ever used.
            text_field_embedder = None
        self._text_field_embedder = text_field_embedder

        # Width of the optional linear projection applied to the ELMo output.
        elmo_projection_dim = 200

        # NOTE(review): the guard of this section was missing; it is assumed
        # to be the same flag that selects the projected width for the
        # highway input further down.
        if cf_a.Add_Linear_projection_ELMO:
            if self.cf_a.VB_Linear_projection_ELMO:
                prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior))
                print("----------------- Bayesian Linear Projection ELMO --------------")
                linear_projection_ELMO = LinearVB(text_field_embedder.get_output_dim(),
                                                  elmo_projection_dim, prior = prior)
            else:
                linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(),
                                                         elmo_projection_dim)
            self._linear_projection_ELMO = linear_projection_ELMO

        ############## Highway layers ####################
        # num_highway_layers : ``int``
        #     The number of highway layers to use in between embedding the input and passing it through
        #     the phrase layer.
        if cf_a.Add_Linear_projection_ELMO:
            Input_dimension_highway = elmo_projection_dim
        else:
            Input_dimension_highway = text_field_embedder.get_output_dim()
        num_highway_layers = cf_a.num_highway_layers

        if self.cf_a.VB_highway_layers:
            print("----------------- Bayesian Highway network  --------------")
            prior = Vil.Prior(**(cf_a.VB_highway_layers_prior))
            highway_layer = HighwayVB(Input_dimension_highway,
                                      num_highway_layers, prior = prior)
        else:
            highway_layer = Highway(Input_dimension_highway,
                                    num_highway_layers)
        highway_layer = TimeDistributed(highway_layer)
        self._highway_layer = highway_layer

        ############## Phrase layer ####################
        # phrase_layer : ``Seq2SeqEncoder``
        #     The encoder (with its own internal stacking) that we will use in between embedding tokens
        #     and doing the bidirectional attention.
        if cf_a.phrase_layer_dropout > 0:       ## Create dropout layer
            dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout)
        else:
            dropout_phrase_layer = lambda x: x
        phrase_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(Input_dimension_highway, hidden_size = cf_a.phrase_layer_hidden_size,
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.phrase_layer_num_layers, dropout = cf_a.phrase_layer_dropout))
        # Bidirectional LSTM: forward and backward states are concatenated.
        phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2
        self._phrase_layer = phrase_layer
        self._dropout_phrase_layer = dropout_phrase_layer

        ############## Matrix attention layer ####################
        # similarity_function : ``SimilarityFunction``
        #     The similarity function that we will use when comparing encoded passage and question
        if self.cf_a.VB_similarity_function:
            prior = Vil.Prior(**(cf_a.VB_similarity_function_prior))
            print("----------------- Bayesian Similarity matrix --------------")
            similarity_function = LinearSimilarityVB(
                  combination = "x,y,x*y",
                  tensor_1_dim =  phrase_encoding_out_dim,
                  tensor_2_dim = phrase_encoding_out_dim, prior = prior)
        else:
            similarity_function = LinearSimilarity(
                  combination = "x,y,x*y",
                  tensor_1_dim =  phrase_encoding_out_dim,
                  tensor_2_dim = phrase_encoding_out_dim)
        matrix_attention = LegacyMatrixAttention(similarity_function)
        self._matrix_attention = matrix_attention

        ############## Modelling Layer ####################
        # modeling_layer : ``Seq2SeqEncoder``
        #     The encoder (with its own internal stacking) that we will use in between the bidirectional
        #     attention and predicting span start and end.
        if cf_a.modeling_passage_dropout > 0:       ## Create dropout layer
            dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout)
        else:
            dropout_modeling_passage = lambda x: x
        modeling_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(phrase_encoding_out_dim * 4, hidden_size = cf_a.modeling_passage_hidden_size,
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.modeling_passage_num_layers, dropout = cf_a.modeling_passage_dropout))

        self._modeling_layer = modeling_layer
        self._dropout_modeling_passage = dropout_modeling_passage

        ############## Span Start Representation #####################
        # span_end_encoder : ``Seq2SeqEncoder``
        #     The encoder that we will use to incorporate span start predictions into the passage state
        #     before predicting span end.
        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        # Linear layer to compute the span start scores
        if self.cf_a.VB_span_start_predictor_linear:
            prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior))
            print("----------------- Bayesian Span Start Predictor--------------")
            span_start_predictor_linear = LinearVB(span_start_input_dim, 1, prior = prior)
        else:
            span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1)
        self._span_start_predictor_linear = span_start_predictor_linear
        self._span_start_predictor = TimeDistributed(span_start_predictor_linear)

        ############## Span End Representation #####################
        ## Create dropout layer
        if cf_a.span_end_encoder_dropout > 0:
            dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_span_end_encode = lambda x: x
        span_end_encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(encoding_dim * 4 + modeling_dim * 3, hidden_size = cf_a.modeling_span_end_hidden_size,
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.modeling_span_end_num_layers, dropout = cf_a.span_end_encoder_dropout))
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_encoder = span_end_encoder
        self._dropout_span_end_encode = dropout_span_end_encode
        if self.cf_a.VB_span_end_predictor_linear:
            print("----------------- Bayesian Span End Predictor--------------")
            prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior))
            span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior = prior)
        else:
            span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1)
        self._span_end_predictor_linear = span_end_predictor_linear
        self._span_end_predictor = TimeDistributed(span_end_predictor_linear)

        # Dropout of the last layers
        if cf_a.spans_output_dropout > 0:
            # The original passed ``cf_a.span_end_encoder_dropout`` here even
            # though the branch is gated on ``spans_output_dropout``; the
            # rate now matches the flag that enables it.
            dropout_spans_output = torch.nn.Dropout(p=cf_a.spans_output_dropout)
        else:
            dropout_spans_output = lambda x: x
        self._dropout_spans_output = dropout_spans_output

        # Checkings and accuracy
        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(Input_dimension_highway , phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        # mask_lstms : ``bool``, optional (default=True)
        #     If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        #     with only a slight performance decrease, if any.  We haven't experimented much with this
        #     yet, but have confirmed that we still get very similar performance with much faster
        #     training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        #     required when using masking with pytorch LSTMs.
        self._mask_lstms = cf_a.mask_lstms

        ################### Initialize parameters ##############################
        # NOTE(review): no statement is visible under this heading; if an
        # initializer call belongs here it was lost -- confirm.
        ####################### OPTIMIZER ################
        optimizer = pytut.get_optimizers(self, cf_a)
        self._optimizer = optimizer
Exemple #49
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30) -> None:
        """Assemble the encoders, attention layers, predictors and metrics.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Forwarded to the parent constructor.
        text_field_embedder : ``TextFieldEmbedder``
            Stored as ``self._text_field_embedder``; its output dim takes
            part in the dimension check below.
        phrase_layer : ``Seq2SeqEncoder``
            Stored as ``self._phrase_layer``. Its output dim becomes
            ``self._encoding_dim``, which sizes every linear layer built here.
        residual_encoder : ``Seq2SeqEncoder``
            Stored unchanged as ``self._residual_encoder``.
        span_start_encoder : ``Seq2SeqEncoder``
            Stored unchanged as ``self._span_start_encoder``.
        span_end_encoder : ``Seq2SeqEncoder``
            Stored unchanged as ``self._span_end_encoder``.
        initializer : ``InitializerApplicator``
            Applied to the module once all layers exist.
        dropout : ``float``, optional (default=0.2)
            Probability handed to ``InputVariationalDropout``.
        num_context_answers : ``int``, optional (default=0)
            When positive, the question-number and previous-answer marker
            embeddings are created.
        marker_embedding_dim : ``int``, optional (default=10)
            Embedding width of the previous-answer markers.
        max_span_length : ``int``, optional (default=30)
            Stored as ``self._max_span_length``.

        NOTE(review): the parent-constructor call, the trailing arguments of
        both ``LinearMatrixAttention`` calls, the head of the dimension
        check and the initializer call were missing from the original text.
        They are reconstructed here -- confirm against upstream.
        """
        # Must run before any sub-module is assigned to ``self``.
        super().__init__(vocab)

        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()
        # Number of rows in the question-number marker embedding table.
        max_turn_length = 12

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim,
                                                       'x,y,x*y')
        self._merge_atten = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))
        self.t = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 2, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(
                max_turn_length, marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding(
                (num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim,
                                                     self._encoding_dim,
                                                     'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(
            phrase_layer.get_input_dim(),
            text_field_embedder.get_output_dim() +
            marker_embedding_dim * num_context_answers,
            "phrase layer input dim",
            "embedding dim + marker dim * num context answers")

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)
Exemple #50
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        embedding_dropout: float,
        pre_encode_feedforward: FeedForward,
        encoder: Seq2SeqEncoder,
        integrator: Seq2SeqEncoder,
        integrator_dropout: float,
        output_layer: Union[FeedForward, Maxout],
        elmo: Elmo = None,
        use_input_elmo: bool = False,
        use_integrator_output_elmo: bool = False,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        """Wire up the classifier's layers and validate their dimensions.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Forwarded to the parent constructor; the size of its
            ``"labels"`` namespace is the number of classes.
        text_field_embedder : ``TextFieldEmbedder``
            Must not itself contain a token embedder named ``"elmo"``.
        embedding_dropout : ``float``
            Probability for the dropout applied via ``self._embedding_dropout``.
        pre_encode_feedforward : ``FeedForward``
        encoder : ``Seq2SeqEncoder``
        integrator : ``Seq2SeqEncoder``
            Stored unchanged; their dimensions are cross-checked below.
        integrator_dropout : ``float``
            Probability for ``self._integrator_dropout``.
        output_layer : ``Union[FeedForward, Maxout]``
            Its output dim must equal the number of classes.
        elmo : ``Elmo``, optional (default=None)
            Required when either ``use_*_elmo`` flag is set, rejected when
            neither is.
        use_input_elmo : ``bool``, optional (default=False)
        use_integrator_output_elmo : ``bool``, optional (default=False)
            The number of flags set must equal the number of scalar mixes
            of ``elmo``.
        initializer : ``InitializerApplicator``, optional
            Applied to the module once all layers exist.
        **kwargs
            Forwarded to the parent constructor.

        Raises
        ------
        ConfigurationError
            For any inconsistent ELMo configuration described above.

        NOTE(review): ``self``, ``**kwargs``, the closing parentheses of the
        ``raise`` calls, the ``else:`` branches, every
        ``check_dimensions_match(`` head with its first two arguments, the
        closing brace of ``metrics`` and the initializer call were missing
        from the original text. They are reconstructed from the surviving
        message strings -- confirm against upstream.
        """
        super().__init__(vocab, **kwargs)

        self._text_field_embedder = text_field_embedder
        if "elmo" in self._text_field_embedder._token_embedders.keys():
            raise ConfigurationError(
                "To use ELMo in the BiattentiveClassificationNetwork input, "
                "remove elmo from the text_field_embedder and pass an "
                "Elmo object to the BiattentiveClassificationNetwork and set the "
                "'use_input_elmo' and 'use_integrator_output_elmo' flags accordingly."
            )
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._num_classes = self.vocab.get_vocab_size("labels")

        self._pre_encode_feedforward = pre_encode_feedforward
        self._encoder = encoder
        self._integrator = integrator
        self._integrator_dropout = nn.Dropout(integrator_dropout)

        self._elmo = elmo
        self._use_input_elmo = use_input_elmo
        self._use_integrator_output_elmo = use_integrator_output_elmo
        self._num_elmo_layers = int(self._use_input_elmo) + int(self._use_integrator_output_elmo)
        # Check that, if elmo is None, none of the elmo flags are set.
        if self._elmo is None and self._num_elmo_layers != 0:
            raise ConfigurationError(
                "One of 'use_input_elmo' or 'use_integrator_output_elmo' is True, "
                "but no Elmo object was provided upon construction. Pass in an Elmo "
                "object to use Elmo."
            )

        if self._elmo is not None:
            # Check that, if elmo is not None, we use it somewhere.
            if self._num_elmo_layers == 0:
                raise ConfigurationError(
                    "Elmo object provided upon construction, but both 'use_input_elmo' "
                    "and 'use_integrator_output_elmo' are 'False'. Set one of them to "
                    "'True' to use Elmo, or do not provide an Elmo object upon construction."
                )
            # Check that the number of flags set is equal to the num_output_representations of the Elmo object

            if len(self._elmo._scalar_mixes) != self._num_elmo_layers:
                raise ConfigurationError(
                    f"Elmo object has num_output_representations={len(self._elmo._scalar_mixes)}, but this "
                    f"does not match the number of use_*_elmo flags set to true. use_input_elmo "
                    f"is {self._use_input_elmo}, and use_integrator_output_elmo "
                    f"is {self._use_integrator_output_elmo}"
                )

        # Calculate combined integrator output dim, taking into account elmo
        if self._use_integrator_output_elmo:
            self._combined_integrator_output_dim = (
                self._integrator.get_output_dim() + self._elmo.get_output_dim()
            )
        else:
            self._combined_integrator_output_dim = self._integrator.get_output_dim()

        self._self_attentive_pooling_projection = nn.Linear(self._combined_integrator_output_dim, 1)
        self._output_layer = output_layer

        if self._use_input_elmo:
            check_dimensions_match(
                text_field_embedder.get_output_dim() + self._elmo.get_output_dim(),
                self._pre_encode_feedforward.get_input_dim(),
                "text field embedder output dim + ELMo output dim",
                "Pre-encoder feedforward input dim",
            )
        else:
            check_dimensions_match(
                text_field_embedder.get_output_dim(),
                self._pre_encode_feedforward.get_input_dim(),
                "text field embedder output dim",
                "Pre-encoder feedforward input dim",
            )

        check_dimensions_match(
            self._pre_encode_feedforward.get_output_dim(),
            self._encoder.get_input_dim(),
            "Pre-encoder feedforward output dim",
            "Encoder input dim",
        )
        check_dimensions_match(
            self._encoder.get_output_dim() * 3,
            self._integrator.get_input_dim(),
            "Encoder output dim * 3",
            "Integrator input dim",
        )
        if self._use_integrator_output_elmo:
            check_dimensions_match(
                self._combined_integrator_output_dim * 4,
                self._output_layer.get_input_dim(),
                "(Integrator output dim + ELMo output dim) * 4",
                "Output layer input dim",
            )
        else:
            check_dimensions_match(
                self._integrator.get_output_dim() * 4,
                self._output_layer.get_input_dim(),
                "Integrator output dim * 4",
                "Output layer input dim",
            )

        check_dimensions_match(
            self._output_layer.get_output_dim(),
            self._num_classes,
            "Output layer output dim",
            "Number of classes.",
        )

        self.metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3),
        }
        self.loss = torch.nn.CrossEntropyLoss()
        initializer(self)
Exemple #51
class DialogQA(Model):
    """
    This class implements a modified version of the BiDAF model
    (with self attention and a residual layer, from the Clark and Gardner ACL 17 paper), as used in
    the Question Answering in Context (EMNLP 2018) paper.

    In this set-up, a single instance is a dialog, i.e. a list of question-answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    max_span_length : ``int``, optional (default=30)
        Maximum token length of the output span.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30) -> None:
        """Assemble the encoders, attention layers, predictors and metrics.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Forwarded to the parent constructor.
        text_field_embedder : ``TextFieldEmbedder``
            Stored as ``self._text_field_embedder``; its output dim takes
            part in the dimension check below.
        phrase_layer : ``Seq2SeqEncoder``
            Stored as ``self._phrase_layer``. Its output dim becomes
            ``self._encoding_dim``, which sizes every linear layer built here.
        residual_encoder : ``Seq2SeqEncoder``
            Stored unchanged as ``self._residual_encoder``.
        span_start_encoder : ``Seq2SeqEncoder``
            Stored unchanged as ``self._span_start_encoder``.
        span_end_encoder : ``Seq2SeqEncoder``
            Stored unchanged as ``self._span_end_encoder``.
        initializer : ``InitializerApplicator``
            Applied to the module once all layers exist.
        dropout : ``float``, optional (default=0.2)
            Probability handed to ``InputVariationalDropout``.
        num_context_answers : ``int``, optional (default=0)
            When positive, the question-number and previous-answer marker
            embeddings are created.
        marker_embedding_dim : ``int``, optional (default=10)
            Embedding width of the previous-answer markers.
        max_span_length : ``int``, optional (default=30)
            Stored as ``self._max_span_length``.

        NOTE(review): the parent-constructor call, the trailing arguments of
        both ``LinearMatrixAttention`` calls, the head of the dimension
        check and the initializer call were missing from the original text.
        They are reconstructed here -- confirm against upstream.
        """
        # Must run before any sub-module is assigned to ``self``.
        super().__init__(vocab)

        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()
        # Number of rows in the question-number marker embedding table.
        max_turn_length = 12

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim,
                                                       'x,y,x*y')
        self._merge_atten = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))
        self.t = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 2, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(
                max_turn_length, marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding(
                (num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim,
                                                     self._encoding_dim,
                                                     'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(
            phrase_layer.get_input_dim(),
            text_field_embedder.get_output_dim() +
            marker_embedding_dim * num_context_answers,
            "phrase layer input dim",
            "embedding dim + marker dim * num context answers")

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            answer: Dict[str, torch.LongTensor],
            dialog: Dict[str, torch.LongTensor],
            previous_answer_appended: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage token will have assigned 'O', except the passage tokens belongs to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        An output dictionary consisting of the followings.
        Each of the followings is a nested list because first iterates over dialog, then questions in dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: np)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        #question = previous_answer_appended
        batch_size, max_qa_count, max_q_len, _ = question[
            'token_characters'].size()"dialog shape token charcaters is %s %s", dialog['token_characters'].size(), dialog['elmo'].size())"question shape token charcaters is %s %s", question['token_characters'].size(), question['elmo'].size())
        batch_size, max_dia_count, max_dia_len, _ = dialog[
        total_qa_count = batch_size * max_qa_count
        qa_mask =, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question,
                                                      num_wrapping_dims=1)"11111111111 dialog is %s", dialog['token_characters'].shape)"11111111111 dialog is %s", dialog['elmo'].shape)

        embedded_dialog = self._text_field_embedder(dialog,
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
        embedded_dialog = embedded_dialog.reshape(
            total_qa_count, max_dia_len,
        embedded_question = self._variational_dropout(embedded_question)
        embedded_dialog = self._variational_dropout(embedded_dialog)
        embedded_passage = self._variational_dropout(
        passage_length = embedded_passage.size(1)"embedded question has shape %s", embedded_question.shape)"embedded dialog has shape %s", embedded_dialog.shape)
        question_mask = util.get_text_field_mask(question,
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        dialog_mask = util.get_text_field_mask(dialog,
        dialog_mask = dialog_mask.reshape(total_qa_count, max_dia_len)
        passage_mask = util.get_text_field_mask(passage).float()"dialog shape token charcaters is %s %s", dialog['token_characters'].size(), dialog['elmo'].size())"answer shape token charcaters is %s %s", answer['token_characters'].size(), answer['elmo'].size())"quesion shape token charcaters is %s %s", question['token_characters'].size(), question['elmo'].size())"previous answer shape token charcaters is %s %s", previous_answer_appended['token_characters'].size(), previous_answer_appended['elmo'].size())
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(
            embedded_question =
                [embedded_question, question_num_marker_emb], dim=-1)

            # Append dialog number for dialog
            question_num_ind = util.get_range_vector(
                max_dia_count, util.get_device_of(embedded_dialog))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_dia_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_dia_len)
            question_num_marker_emb = self._question_num_marker(
            embedded_dialog =
                [embedded_dialog, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage =
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage =
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                    repeated_embedded_passage =
                        [repeated_embedded_passage, p3_answer_marker_emb],

            repeated_encoded_passage = self._variational_dropout(
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)"repeated encoded passage has shape %s", repeated_encoded_passage.shape)"embedded question has shape %s", embedded_question.shape)"question mask has shape %s",  question_mask.shape)"embedded dialog has shape %s", embedded_dialog.shape)"dialog mask has shape %s",  dialog_mask.shape)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))
        encoded_dialog = self._variational_dropout(
            self._phrase_layer(embedded_dialog, dialog_mask))"encoded_question is %s", encoded_question.shape)"encoded_dialog is %s", encoded_dialog.shape)"encoded_passage is %s", repeated_encoded_passage.shape)

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)"passage_question_similarity is %s", passage_question_similarity.shape)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)"question_mask is %s", question_mask.shape)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)"passage question vectors is %s", passage_question_vectors.shape)

        ############################# DIALOG SIMILARITY STUFF ################################################################
        dialog_question_similarity = self._matrix_attention(
            encoded_question, encoded_dialog)"dialog question similarity is %s", dialog_question_similarity.shape)"dialog_mask is %s", dialog_mask.shape)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        dialog_question_attention = util.masked_softmax(
            dialog_question_similarity, dialog_mask)"dialog question attention is %s", dialog_question_attention.shape)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        question_dialog_vectors = util.weighted_sum(encoded_dialog,
                                                    dialog_question_attention)"question_dialog_vectors is %s", question_dialog_vectors.shape)"encoded_question 111  %s", encoded_question.shape)"encoded_question 2222  %s", encoded_question.shape)"self._encoding_dim  %s", self._encoding_dim)"encoded_question 3333333  %s", encoded_question.shape)


        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        #if max_qa_count == 7 and batch_size == 21:
        #    sys.exit()
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        question_passage_similarity = masked_similarity.max(
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        ############################################## Create qc for question tokens #########################################

        masked_similarity = util.replace_masked_values(
            dialog_question_similarity, dialog_mask.unsqueeze(1), -1e7)

        dialog_question_similarity = masked_similarity.max(
        dialog_question_attention = util.masked_softmax(
            dialog_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        dialog_question_vector = util.weighted_sum(encoded_question,
        tiled_dialog_question_vector = dialog_question_vector.unsqueeze(
            1).expand(total_qa_count, max_q_len, self._encoding_dim)

        encoded_question =[
            encoded_question, question_dialog_vectors,
            encoded_question * question_dialog_vectors,
            encoded_question * tiled_dialog_question_vector

        #encoded_question = self.t(encoded_question)
        encoded_question = F.relu(self._merge_atten(encoded_question))


        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage =[
            repeated_encoded_passage, passage_question_vectors,
            repeated_encoded_passage * passage_question_vectors,
            repeated_encoded_passage * tiled_question_passage_vector

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(
        self_attention_matrix = self._self_attention(residual_layer,

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs,
        self_attention_vecs =[
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        residual_layer = F.relu(

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage,
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(
  [final_merged_passage, start_rep], dim=-1),
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(

        span_start_logits = util.replace_masked_values(span_start_logits,
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits,

        best_span = self._get_best_span_yesno_followup(span_start_logits,

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(
            for i in range(0, total_qa_count):
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc =

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end =

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
            loss += nll_loss(F.log_softmax(_followup, dim=-1),

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                predicted_span = tuple(best_span_cpu[i * max_qa_count +

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]

                best_span_string = passage_str[start_offset:end_offset]
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            refs = [answer_texts[z] for z in idxes]
                                    squad_eval.f1_score, best_span_string,
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                self._official_f1(100 * f1_score)
        return output_dict

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Replace the ``yesno`` and ``followup`` index predictions with label tokens.

        Both entries of ``output_dict`` are treated as nested sequences of
        vocabulary indices.  Every index is looked up through
        ``self.vocab.get_token_from_index`` in the ``yesno_labels`` and
        ``followup_labels`` namespaces respectively, keeping the nesting intact.
        The dictionary is modified in place and also returned.

        Raises ``KeyError`` if either key is absent.
        """
        def tokens_for(index_rows, label_namespace):
            # Map each index of each row through the vocabulary, preserving the nesting.
            converted = []
            for row in index_rows:
                converted.append([
                    self.vocab.get_token_from_index(index, namespace=label_namespace)
                    for index in row
                ])
            return converted

        # Both keys are popped (and converted) before either is written back.
        decoded_yesno = tokens_for(output_dict.pop("yesno"), "yesno_labels")
        decoded_followup = tokens_for(output_dict.pop("followup"), "followup_labels")
        for key, tags in (('yesno', decoded_yesno), ('followup', decoded_followup)):
            output_dict[key] = tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the current value of every metric tracked by the model.

        Parameters
        ----------
        reset : ``bool``, optional (default=False)
            Forwarded unchanged to each metric's ``get_metric`` call.

        Returns
        -------
        A dictionary with the keys ``start_acc``, ``end_acc``, ``span_acc``,
        ``yesno``, ``followup`` and ``f1``, each holding whatever the
        corresponding metric object returns from ``get_metric(reset)``.
        """
        # The dict literal was left unterminated in the original block; it is closed here.
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'yesno': self._span_yesno_accuracy.get_metric(reset),
            'followup': self._span_followup_accuracy.get_metric(reset),
            'f1': self._official_f1.get_metric(reset),
        }

    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        # Returns the index of highest-scoring span that is not longer than 30 tokens, as well as
        # yesno prediction bit and followup prediction bit from the predicted span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4),

        span_start_logits =
        span_end_logits =
        span_yesno_logits =
        span_followup_logits =
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
Exemple #52
    # Constructor of the ``AttentionQA`` model (class name taken from the ``super`` call
    # below).  It builds the detector, the time-distributed encoders, several attention
    # modules, the pooling/classification MLPs and batch-norm layers, then optionally
    # copies pre-trained parameters into ``source_encoder``.
    #
    # NOTE(review): this copy of the method is truncated.  The ``self`` parameter and the
    # closing ``):`` of the signature are missing, and several calls below
    # (``SimpleDetector(``, ``TimeDistributed(``, every ``BilinearMatrixAttention(``,
    # the ``dim = sum([`` expression, the ``torch.nn.Sequential(`` stacks and the call
    # around the "preload paramter" message) lack arguments and/or closing brackets.
    # Restore them from the original source before relying on this code.
    def __init__(
            vocab: Vocabulary,
            span_encoder: Seq2SeqEncoder,
            reasoning_encoder: Seq2SeqEncoder,
            input_dropout: float = 0.1,
            hidden_dim_maxpool: int = 512,
            class_embs: bool = True,
            reasoning_use_obj: bool = True,
            reasoning_use_answer: bool = True,
            reasoning_use_question: bool = True,
            pool_reasoning: bool = True,
            pool_answer: bool = True,
            pool_question: bool = False,
            # NOTE(review): the default is "" while the guard further down tests
            # ``is not None``, so with the default value ``torch.load("")`` would be
            # attempted -- confirm the intended default.
            preload_path: str = "",
            # NOTE(review): this default instance is created once at definition time and
            # shared by every call that omits the argument; ``initializer`` is also not
            # used in the visible part of the body.
            initializer: InitializerApplicator = InitializerApplicator(),
        super(AttentionQA, self).__init__(vocab)

        self.detector = SimpleDetector(pretrained=True,

        # Input dropout is only created when ``input_dropout`` is positive; otherwise the
        # attribute is ``None``.
        self.rnn_input_dropout = TimeDistributed(
                input_dropout)) if input_dropout > 0 else None

        self.span_encoder = TimeDistributed(span_encoder)
        self.reasoning_encoder = TimeDistributed(reasoning_encoder)
        # The numeric arguments are hard-coded; their meaning is defined by ``MYLSTM`` /
        # ``source_LSTM``, which are not visible here.
        self.BiLSTM = TimeDistributed(MYLSTM(1280, 512, 256))
        self.source_encoder = TimeDistributed(source_LSTM(768, 256))

        self.span_attention = BilinearMatrixAttention(
        self.span_attention_2 = BilinearMatrixAttention(

        self.obj_attention = BilinearMatrixAttention(

        self.obj_attention_2 = BilinearMatrixAttention(

        self._matrix_attention = DotProductMatrixAttention()
        #self._matrix_attention = MatrixAttention(similarity_function)

        self.reasoning_use_obj = reasoning_use_obj
        self.reasoning_use_answer = reasoning_use_answer
        self.reasoning_use_question = reasoning_use_question
        self.pool_reasoning = pool_reasoning
        self.pool_answer = pool_answer
        self.pool_question = pool_question
        # ``dim`` adds up the output dimension of each encoder whose pooling flag is set:
        # the reasoning encoder for ``pool_reasoning`` and the span encoder once each for
        # ``pool_answer`` and ``pool_question``.
        dim = sum([
            d for d, to_pool in [(
                reasoning_encoder.get_output_dim(), self.pool_reasoning
            ), (span_encoder.get_output_dim(), self.pool_answer
                ), (span_encoder.get_output_dim(), self.pool_question)]
            if to_pool

        # Two MLP heads with the same visible layers: dropout, ``dim`` ->
        # ``hidden_dim_maxpool``, dropout, ``hidden_dim_maxpool`` -> 1.
        # NOTE(review): any layers that were between these lines are missing in this copy.
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(dim, hidden_dim_maxpool),
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(hidden_dim_maxpool, 1),
        self.final_mlp_2 = torch.nn.Sequential(
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(dim, hidden_dim_maxpool),
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(hidden_dim_maxpool, 1),

        # Batch-norm layers and the final projections use a hard-coded width of 512
        # (and 2560 -> 512 for the pooling MLP) rather than a constructor argument.
        self.answer_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.question_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.source_answer_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.source_question_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.image_BN = BatchNorm1d(512)
        self.final_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.final_mlp_linear = torch.nn.Sequential(torch.nn.Linear(512, 1))
        self.final_mlp_pool = torch.nn.Sequential(
            torch.nn.Linear(2560, 512),
            torch.nn.Dropout(input_dropout, inplace=False),

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        # Optional warm start: every entry of the loaded mapping whose name starts with
        # "LSTM" is copied into this model's state entry
        # "source_encoder._module" + <rest of the name>.
        # NOTE(review): ``torch.load`` unpickles the file; only load trusted checkpoints.
        if preload_path is not None:
            preload = torch.load(preload_path)
            own_state = self.state_dict()
            for name, param in preload.items():
                #if name[0:8] == "_encoder":
                #    suffix = "._module."+name[9:]
                #"preload paramter {}".format("span_encoder"+suffix))
                #    own_state["span_encoder"+suffix].copy_(param)
                if name[0:4] == "LSTM":
                    suffix = "._module" + name[4:]
          "preload paramter {}".format("source_encoder" +
                    own_state["source_encoder" + suffix].copy_(param)
Exemple #53
class CosmosQATask(MultipleChoiceTask):
    """ Task class for the CosmosQA multiple-choice task.

        The data is read from ``train.csv``, ``valid.csv`` and ``test_no_label.csv``
        under ``path``; each record provides ``n_choices`` (4) answer options.

        NOTE(review): the original docstring ended with "adaptation of preprocessing
        from" and the reference that followed is missing in this copy of the file. """
    def __init__(self, path, max_seq_len, name, **kw):
        """ Store the data location and sequence limit and set up the accuracy scorer.

        ``name`` and ``**kw`` are forwarded to the parent constructor; ``name`` is also
        used to build ``val_metric`` ("<name>_accuracy"). """
        super().__init__(name, **kw)
        self.path = path
        self.max_seq_len = max_seq_len

        # Populated by ``load_data``.
        self.train_data_text = None
        self.val_data_text = None
        self.test_data_text = None

        self.scorer1 = CategoricalAccuracy()
        self.scorers = [self.scorer1]
        self.val_metric = "%s_accuracy" % name
        self.val_metric_decreases = False
        # Number of ``answer<i>`` columns read from every record in ``_load_csv``.
        self.n_choices = 4

    def load_data(self):
        """ Process the dataset located at path.

        Loads the three CSV splits through ``_load_csv`` and gathers the train and
        validation contexts plus all of their answer choices into ``self.sentences``
        (the test split is loaded but not added to ``self.sentences``). """
        self.train_data_text = self._load_csv(
            os.path.join(self.path, "train.csv"))
        self.val_data_text = self._load_csv(
            os.path.join(self.path, "valid.csv"))
        self.test_data_text = self._load_csv(
            os.path.join(self.path, "test_no_label.csv"))
        self.sentences = (self.train_data_text[0] + self.val_data_text[0] + [
            choice for choices in self.train_data_text[1] for choice in choices
        ] + [
            choice for choices in self.val_data_text[1] for choice in choices
        ])"\tFinished loading CosmosQA data.")
        # NOTE(review): the line above looks like a truncated logging call fused onto the
        # closing bracket of the expression -- TODO restore the original statement.

    def _load_csv(self, input_file):
        """ Read one CSV split and return ``[contexts, choices, targs, id_str]``.

        Each record is expected to have the columns ``question``, ``answer0`` ..
        ``answer<n_choices - 1>`` and ``id``; ``label`` is optional and the target
        falls back to 0 when it is absent.

        NOTE(review): this copy of the method is truncated: the two list expressions
        are not closed, the call wrapping ``question + " " + ans_choices[i]`` and the
        text argument of ``tokenize_and_truncate`` are missing, and no statement that
        fills the four result lists is visible. """
        import csv

        with open(input_file, "r") as csv_file:
            reader = csv.DictReader(csv_file)
            records = [record for record in reader]

        contexts, choices, targs, id_str = [], [], [], []
        for record in records:
            question = record["question"]

            ans_choices = [
                record["answer" + str(i)] for i in range(self.n_choices)
            qa_tok_choices = [
                                      question + " " + ans_choices[i],
                for i in range(len(ans_choices))
            # The context budget is the sequence limit minus the longest
            # question+answer candidate.
            max_ans_len = max([len(tok) for tok in qa_tok_choices])
            context = tokenize_and_truncate(self._tokenizer_name,
                                            self.max_seq_len - max_ans_len)
            targ = int(record["label"]) if "label" in record else 0
            idx = record["id"]
        return [contexts, choices, targs, id_str]

    def process_split(
            self, split, indexers,
            model_preprocessing_interface) -> Iterable[Type[Instance]]:
        """ Process split text into a list of AllenNLP Instances.

        ``split`` is unpacked into parallel sequences that are passed element-wise to
        ``_make_instance`` as (context, choices, label, id_str).  The result is the
        lazy ``map`` object, not a materialised list.

        NOTE(review): several calls inside ``_make_instance`` are truncated in this
        copy (the ``model_flags[...]`` lookup, the ``sentence_to_text_field(`` call for
        the context, the ``else`` branch of the ``inp`` expression and the remaining
        ``LabelField(`` arguments). """
        def _make_instance(context, choices, label, id_str):
            # Build the field dictionary for a single example.
            d = {}
            d["context_str"] = MetadataField(" ".join(context))
            if not model_preprocessing_interface.model_flags[
                d["context"] = sentence_to_text_field(
            for choice_idx, choice in enumerate(choices):
                # With pair embeddings the context and the choice go through the
                # boundary-token function together.
                inp = (model_preprocessing_interface.boundary_token_fn(
                    context, choice) if model_preprocessing_interface.
                       model_flags["uses_pair_embedding"] else
                d["choice%d" % choice_idx] = sentence_to_text_field(
                    inp, indexers)
                d["choice%d_str" % choice_idx] = MetadataField(
                    " ".join(choice))
            d["label"] = LabelField(label,
            d["id_str"] = MetadataField(id_str)
            return Instance(d)

        split = list(split)
        instances = map(_make_instance, *split)
        return instances

    def get_metrics(self, reset=False):
        """Get metrics specific to the task

        Returns a dictionary with the single key ``accuracy`` holding the value of
        ``self.scorer1.get_metric(reset)``."""
        acc = self.scorer1.get_metric(reset)
        return {"accuracy": acc}
class QuarelSemanticParser(Model):
    A ``QuarelSemanticParser`` is a variant of ``WikiTablesSemanticParser`` with various
    tweaks and changes.

    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=0)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  A value of 8 matches the default in the ``KnowledgeGraphField``,
        which is to use all eight defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question.
    use_entities : ``bool``, optional (default=False)
        Whether dynamic entities are part of the action space
    num_entity_bits : ``int``, optional (default=0)
        Whether any bits are added to encoder input/output to represent tagged entities
    entity_bits_output : ``bool``, optional (default=True)
        Whether entity bits are added to the encoder output or input
    denotation_only : ``bool``, optional (default=False)
        Whether to only predict target denotation, skipping the whole logical form decoder
    entity_similarity_mode : ``str``, optional (default="dot_product")
        How to compute vector similarity between question and entity tokens, can take values
        "dot_product" or "weighted_dot_product" (learned weights on each dimension)
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    def __init__(
            vocab: Vocabulary,
            question_embedder: TextFieldEmbedder,
            action_embedding_dim: int,
            encoder: Seq2SeqEncoder,
            decoder_beam_search: BeamSearch,
            max_decoding_steps: int,
            attention: Attention,
            mixture_feedforward: FeedForward = None,
            add_action_bias: bool = True,
            dropout: float = 0.0,
            num_linking_features: int = 0,
            num_entity_bits: int = 0,
            entity_bits_output: bool = True,
            use_entities: bool = False,
            denotation_only: bool = False,
            # Deprecated parameter to load older models
            entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
            entity_similarity_mode: str = "dot_product",
            rule_namespace: str = 'rule_labels') -> None:
        # NOTE(review): this constructor is truncated in this copy of the file.  The parameter
        # list has no ``self`` although the body uses it, the ``else:`` lines of several ``if``
        # statements are absent, and a number of calls are left with unclosed parentheses.
        # Restore the missing lines before relying on it.
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        # ``self._dropout`` is always callable: a real Dropout module or an identity function.
        # NOTE(review): as written the identity lambda overwrites the Dropout module; an
        # ``else:`` line appears to be missing between the two assignments.
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        # Running averages reported by ``get_metrics``.
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        # The similarity layer only exists in "weighted_dot_product" mode, where it is a
        # bias-free linear map from the embedding dimension to a single score.
        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = \
                TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
   += 1  # pylint: disable=protected-access
        elif self._entity_similarity_mode == "dot_product":
            raise ValueError("Invalid entity_similarity_mode: {}".format(

        # Linking features are scored by a learned linear layer only when there are any.
        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        # When entity bits are attached to the encoder output, downstream layers see a wider
        # vector than the encoder itself produces.
        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            self._denotation_classifier = torch.nn.Linear(
                self._encoder_output_dim, self._num_denotation_cats)
            # Rest of init not needed for denotation only where no decoding to actions needed

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions,
        # We are tying the action embeddings used for input and output
        # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions,

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(
        self._first_attended_question = torch.nn.Parameter(

        self._decoder_step = LinkingTransitionFunction(

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[QuarelWorld],
            actions: List[List[ProductionRule]],
            entity_bits: torch.Tensor = None,
            denotation_target: torch.Tensor = None,
            target_action_sequences: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        # NOTE(review): this method is truncated in this copy of the file.  The docstring below
        # has lost its triple-quote delimiters and section headers, several ``else:`` / ``try:``
        # / ``if`` lines are absent, and many calls are left with unclosed parentheses or a
        # missing callee.  Restore the missing lines before relying on it.
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``,
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
        entity_bits : ``torch.Tensor``, optional (default=None)
            Tensor encoding bits for the world entities.
        denotation_target : ``torch.Tensor``, optional (default=None)
            If model's field ``denotation_only`` is True, this is the tensor target denotation.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
        metadata : List[Dict[str, Any]], optional (default=None).
            A dictionary of metadata for each batch element which has keys:
                question_tokens : ``List[str]``, optional.
                    The original string tokens in the question.
                world_extractions : ``nltk.Tree``, optional.
                    Extracted worlds from the question.
                answer_index : ``List[str]``, optional.
                    Index of the correct answer.

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (
                    embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (
                    embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(
                                        num_entities * num_entity_tokens,
                    torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(
                    batch_size, num_entities, num_entity_tokens,

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(
                    question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                # Same normalisation as above; the element-wise product of every
                # (entity token, question token) pair is then scored by the learned layer.
                embedded_table = embedded_table / (
                    embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (
                    embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(
                    -1, num_entities * num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size,
                                          num_entities * num_entity_tokens,
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(
                    num_question_tokens * num_entities * num_entity_tokens,
                question_entity_similarity = self._entity_similarity_layer(
                question_entity_similarity = question_entity_similarity.view(
                    batch_size, num_entities, num_entity_tokens,

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(
                    question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(
                world, linking_scores.transpose(1, 2), question_mask,
            encoder_input = embedded_question
            # NOTE(review): the lines below look like the tail of the entity branch followed by
            # the body of a missing ``else:`` for ``if self._use_entities`` -- confirm.
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input =[embedded_question, entity_bits], 2)
                encoder_input = embedded_question

            # Fake linking_scores added for downstream code to not object
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs =[encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(
            loss = torch.nn.functional.cross_entropy(
                denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size,

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some ``s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_attended_question, encoder_output_list,

        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i],
            for i in range(batch_size)

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(
            action_history=[[] for _ in range(batch_size)],

        # NOTE(review): the two indented sections below read like the bodies of a missing
        # ``if``/``else`` pair (a trainer-only path, then a search path) -- confirm.
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
            return outputs

            # Map (batch index, action index) to the first element of each production rule.
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(
                    initial_state, self._decoder_step,
                    (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states =
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            if self._linking_params is not None:
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                        'question_tokens', []))
                if metadata is not None:
                        'world_extractions', {}))
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(
                            best_action_indices, targets)
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in best_action_indices
                        logical_form = world[i].get_logical_form(
                            action_strings, add_var_function=False)
                    except ParsingError:
                        logical_form = 'Error producing logical form'
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(
                            predicted_answer_index, answer_index)
                    score = math.exp(
                        best_final_states[i][0].debug_info[0])  # type: ignore
            return outputs

    @staticmethod
    def _get_type_vector(
            worlds: List[QuarelWorld], num_entities: int,
            tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's
        type. In addition, a map from a flattened entity index to type is returned to combine
        entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
        num_entities : ``int``
            Length every per-instance type list is padded to.
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    # Anything without a recognised prefix falls into type 0.
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            # Instances with fewer entities than the batch maximum are padded with type 0 so
            # that the rows stack into a rectangular tensor.
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)
        return tensor.new_tensor(batch_types, dtype=torch.long), entity_types

    def _get_linking_probabilities(
            self, worlds: List[QuarelWorld], linking_scores: torch.FloatTensor,
            question_mask: torch.LongTensor,
            entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0
            entities = world.table_graph.entities

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities +
                                        entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices,
                                                    dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(
                    1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores,
                                                                 dim=1)
                # Drop the null-entity column again so only real entities remain.
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(
                    num_question_tokens,
                    num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        # Zero out the rows that belong to padded question tokens.
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    def _action_history_match(predicted: List[int],
                              targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(
            torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()

    def _denotation_match(self, predicted_answer_index: int,
                          target_answer_index: int) -> float:
        if predicted_answer_index < 0:
            # Logical form doesn't properly resolve, we do random guess with appropriate credit
            return 1.0 / self._num_denotation_cats
        elif predicted_answer_index == target_answer_index:
            return 1.0
        return 0.0

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        We track three metrics here:

            1. parse_acc, which is the percentage of the time that our best output action sequence
            corresponds to a correct logical form

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation, including spurious correct answers using the wrong logical form

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        if self._denotation_only:
            metrics = {
            metrics = {
                'parse_acc': self._action_sequence_accuracy.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'lf_percent': self._has_logical_form.get_metric(reset),
        return metrics

    def _create_grammar_state(self, world: QuarelWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        # NOTE(review): this method is truncated in this copy of the file.  The docstring below
        # has lost its triple-quote delimiters and section headers, and several calls and list
        # literals are left unclosed or missing their callee.  Restore the missing lines before
        # relying on it.
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        world : ``QuarelWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        # Look-up tables: production-rule string -> position in ``possible_actions``, and
        # entity name -> position in the world's entity list.
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [
                action_map[action_string] for action_string in action_strings
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                # Element 1 of a production rule selects which of the two lists it joins;
                # the ``append`` callees and the ``else:`` are missing here.
                if production_rule_array[1]:
                        (production_rule_array[2], action_index))
                        (production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions.
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor =, dim=0)
            global_input_embeddings = self._action_embedder(
            if self._add_action_bias:
                global_action_biases = self._action_biases(
                global_input_embeddings =
                    [global_input_embeddings, global_action_biases], dim=-1)
            global_output_embeddings = self._output_action_embedder(
            translated_valid_actions[key]['global'] = (
                global_input_embeddings, global_output_embeddings,

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                # The entity name is the right-hand side of the "lhs -> rhs" rule string.
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(
                translated_valid_actions[key]['linked'] = (
                    entity_linking_scores, entity_type_embeddings,

        return GrammarStatelet([START_SYMBOL], translated_valid_actions,

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder".

        For every predicted action of every batch instance, the candidate actions recorded in the
        step's debug info are translated from indices to action strings through
        ``action_mapping``.  Candidates with index ``-1`` are skipped.  The collected information
        is added to ``output_dict`` under the key ``predicted_actions``.

        Parameters
        ----------
        output_dict : ``Dict[str, torch.Tensor]``
            Must contain ``action_mapping`` (indexed by ``(batch_index, action_index)``),
            ``best_action_sequence`` and ``debug_info``.

        Returns
        -------
        The same ``output_dict`` object, with ``predicted_actions`` set to a list (one entry per
        batch instance) of lists (one entry per predicted action) of dictionaries with the keys
        ``predicted_action``, ``considered_actions``, ``action_probabilities`` and
        ``question_attention``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(
                    predicted_actions, debug_info):
                # Pair each real candidate (index != -1) with its probability, keeping the order
                # in which the candidates were recorded.
                actions = [
                    (action_mapping[(batch_index, action)], probability)
                    for action, probability in zip(
                        action_debug_info['considered_actions'],
                        action_debug_info['probabilities'])
                    if action != -1
                ]
                if actions:
                    considered_actions, probabilities = zip(*actions)
                else:
                    # ``zip(*[])`` cannot be unpacked into two names, so handle the case where
                    # every candidate was skipped explicitly.
                    considered_actions, probabilities = (), ()
                instance_action_info.append({
                    'predicted_action': predicted_action,
                    'considered_actions': considered_actions,
                    'action_probabilities': probabilities,
                    'question_attention': action_debug_info.get(
                        'question_attention', []),
                })
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
    def __init__(
            vocab: Vocabulary,
            question_embedder: TextFieldEmbedder,
            action_embedding_dim: int,
            encoder: Seq2SeqEncoder,
            decoder_beam_search: BeamSearch,
            max_decoding_steps: int,
            attention: Attention,
            mixture_feedforward: FeedForward = None,
            add_action_bias: bool = True,
            dropout: float = 0.0,
            num_linking_features: int = 0,
            num_entity_bits: int = 0,
            entity_bits_output: bool = True,
            use_entities: bool = False,
            denotation_only: bool = False,
            # Deprecated parameter to load older models
            entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
            entity_similarity_mode: str = "dot_product",
            rule_namespace: str = 'rule_labels') -> None:
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = \
                TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
   += 1  # pylint: disable=protected-access
        elif self._entity_similarity_mode == "dot_product":
            raise ValueError("Invalid entity_similarity_mode: {}".format(

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            self._denotation_classifier = torch.nn.Linear(
                self._encoder_output_dim, self._num_denotation_cats)
            # Rest of init not needed for denotation only where no decoding to actions needed

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions,
        # We are tying the action embeddings used for input and output
        # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions,

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(
        self._first_attended_question = torch.nn.Parameter(

        self._decoder_step = LinkingTransitionFunction(
Exemple #56
class ESIM(Model):
    """
    This `Model` implements the ESIM sequence model described in [Enhanced LSTM for Natural Language Inference]
    by Chen et al., 2017.

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the `premise` and `hypothesis` `TextFields` we get as input to the
        model.
    encoder : `Seq2SeqEncoder`
        Used to encode the premise and hypothesis.
    similarity_function : `SimilarityFunction`
        This is the similarity function used when computing the similarity matrix between encoded
        words in the premise and words in the hypothesis.
    projection_feedforward : `FeedForward`
        The feedforward network used to project down the encoded and enhanced premise and hypothesis.
    inference_encoder : `Seq2SeqEncoder`
        Used to encode the projected premise and hypothesis for prediction.
    output_feedforward : `FeedForward`
        Used to prepare the concatenated premise and hypothesis for prediction.
    output_logit : `FeedForward`
        This feedforward network computes the output logits.
    dropout : `float`, optional (default=0.5)
        Dropout percentage to use.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    """
    def __init__(
        # NOTE(review): this signature declares neither `self` nor `**kwargs`, yet the body
        # assigns to `self` and forwards `**kwargs` to the parent constructor.  Both look like
        # lines lost from this listing -- confirm against the original file.
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        similarity_function: SimilarityFunction,
        projection_feedforward: FeedForward,
        inference_encoder: Seq2SeqEncoder,
        output_feedforward: FeedForward,
        output_logit: FeedForward,
        dropout: float = 0.5,
        initializer: InitializerApplicator = InitializerApplicator(),
    ) -> None:
        # Stores the supplied sub-modules on the instance, wraps the similarity function in a
        # matrix attention, sets up the dropout layers, and creates the accuracy metric and the
        # cross-entropy loss.
        # NOTE(review): `initializer` is never used in the lines shown here -- confirm whether a
        # call applying it to the model is missing.
        super().__init__(vocab, **kwargs)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder

        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._projection_feedforward = projection_feedforward

        self._inference_encoder = inference_encoder

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
            self.rnn_input_dropout = InputVariationalDropout(dropout)
            # NOTE(review): as written, the next two assignments sit in the same branch and
            # immediately overwrite the two layers created above; an `else:` line is presumably
            # missing -- confirm.  `forward` only applies these attributes when they are truthy.
            self.dropout = None
            self.rnn_input_dropout = None

        self._output_feedforward = output_feedforward
        self._output_logit = output_logit

        # Number of entries in the "labels" vocabulary namespace.
        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # NOTE(review): the indented lines below are bare arguments with no enclosing call.
        # Judging by the message strings they presumably belonged to dimension-consistency
        # checks between the embedder, encoder, projection and inference encoder; the call
        # heads and the remaining arguments are not present in this listing -- confirm.
            "text field embedding dim",
            "encoder input dim",
            encoder.get_output_dim() * 4,
            "encoder output dim",
            "projection feedforward input",
            "proj feedforward output dim",
            "inference lstm input dim",

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()


    def forward(  # type: ignore
        self,
        premise: TextFieldTensors,
        hypothesis: TextFieldTensors,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Embed and encode both sentences, attend each one over the other, project the enhanced
        representations, run the inference encoder, pool, and classify.

        # Parameters

        premise : TextFieldTensors
            From a `TextField`
        hypothesis : TextFieldTensors
            From a `TextField`
        label : torch.IntTensor, optional (default = None)
            From a `LabelField`
        metadata : `List[Dict[str, Any]]`, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
            It is not read by this method.

        # Returns

        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise)
        hypothesis_mask = get_text_field_mask(hypothesis)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer: each encoding is concatenated with its attended counterpart,
        # their difference and their element-wise product (four parts in total).
        premise_enhanced = torch.cat(
            [
                encoded_premise,
                attended_hypothesis,
                encoded_premise - attended_hypothesis,
                encoded_premise * attended_hypothesis,
            ],
            dim=-1,
        )
        hypothesis_enhanced = torch.cat(
            [
                encoded_hypothesis,
                attended_premise,
                encoded_hypothesis - attended_premise,
                encoded_hypothesis * attended_premise,
            ],
            dim=-1,
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        # NOTE(review): the constant that masked positions were filled with before taking the
        # max was lost from this listing; the smallest finite value of the dtype is used so that
        # padding can never win the max -- confirm against the original.
        v_a_max, _ = replace_masked_values(
            v_ai, premise_mask.unsqueeze(-1), torch.finfo(v_ai.dtype).min
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
            v_bi, hypothesis_mask.unsqueeze(-1), torch.finfo(v_bi.dtype).min
        ).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1),
                            dim=1) / torch.sum(premise_mask, 1, keepdim=True)
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1),
                            dim=1) / torch.sum(
                                hypothesis_mask, 1, keepdim=True)

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Report the metrics tracked by this model.

        # Parameters

        reset : `bool`, optional (default = False)
            Passed through to the accuracy metric's `get_metric`.

        # Returns

        A dictionary with the single key `"accuracy"`.
        """
        accuracy = self._accuracy.get_metric(reset)
        metrics = {"accuracy": accuracy}
        return metrics
Exemple #57
    def __init__(
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        label_namespace: str = "labels",
        feedforward: Optional[FeedForward] = None,
        label_encoding: Optional[str] = None,
        include_start_end_transitions: bool = True,
        constrain_crf_decoding: bool = None,
        calculate_span_f1: bool = None,
        dropout: Optional[float] = None,
        verbose_metrics: bool = False,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
        top_k: int = 1,
    ) -> None:
        super().__init__(vocab, regularizer)

        self.label_namespace = label_namespace
        self.text_field_embedder = text_field_embedder
        self.num_tags = self.vocab.get_vocab_size(label_namespace)
        self.encoder = encoder
        self.top_k = top_k
        self._verbose_metrics = verbose_metrics
        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
            self.dropout = None
        self._feedforward = feedforward

        if feedforward is not None:
            output_dim = feedforward.get_output_dim()
            output_dim = self.encoder.get_output_dim()
        self.tag_projection_layer = TimeDistributed(
            Linear(output_dim, self.num_tags))

        # if  constrain_crf_decoding and calculate_span_f1 are not
        # provided, (i.e., they're None), set them to True
        # if label_encoding is provided and False if it isn't.
        if constrain_crf_decoding is None:
            constrain_crf_decoding = label_encoding is not None
        if calculate_span_f1 is None:
            calculate_span_f1 = label_encoding is not None

        self.label_encoding = label_encoding
        if constrain_crf_decoding:
            if not label_encoding:
                raise ConfigurationError("constrain_crf_decoding is True, but "
                                         "no label_encoding was specified.")
            labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
            constraints = allowed_transitions(label_encoding, labels)
            constraints = None

        self.include_start_end_transitions = include_start_end_transitions
        self.crf = ConditionalRandomField(

        self.metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3),
        self.calculate_span_f1 = calculate_span_f1
        if calculate_span_f1:
            if not label_encoding:
                raise ConfigurationError("calculate_span_f1 is True, but "
                                         "no label_encoding was specified.")
            self._f1_metric = SpanBasedF1Measure(vocab,

            "text field embedding dim",
            "encoder input dim",
        if feedforward is not None:
                "encoder output dim",
                "feedforward input dim",
Exemple #58
class DecomposableAttention(Model):
    """
    This ``Model`` implements the Decomposable Attention model described in "A Decomposable
    Attention Model for Natural Language Inference"
    by Parikh et al., 2016, with some optional enhancements before the decomposable attention
    actually happens.  Parikh's original model allowed for computing an "intra-sentence" attention
    before doing the decomposable entailment step.  We generalize this to any
    :class:`Seq2SeqEncoder` that can be applied to the premise and/or the hypothesis before
    computing entailment.

    The basic outline of this model is to get an embedded representation of each word in the
    premise and hypothesis, align words between the two, compare the aligned phrases, and make a
    final entailment decision based on this aggregated comparison.  Each step in this process uses
    a feedforward network to modify the representation.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
        model.
    attend_feedforward : ``FeedForward``
        This feedforward network is applied to the encoded sentence representations before the
        similarity matrix is computed between words in the premise and words in the hypothesis.
    similarity_function : ``SimilarityFunction``
        This is the similarity function used when computing the similarity matrix between words in
        the premise and words in the hypothesis.
    compare_feedforward : ``FeedForward``
        This feedforward network is applied to the aligned premise and hypothesis representations.
    aggregate_feedforward : ``FeedForward``
        This final feedforward network is applied to the concatenated, summed result of the
        ``compare_feedforward`` network, and its output is used as the entailment class logits.
    initializer : ``InitializerApplicator``
        We will use this to initialize the parameters in the model, calling ``initializer(self)``.
    premise_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        After embedding the premise, we can optionally apply an encoder.  If this is ``None``, we
        will do nothing.
    hypothesis_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        After embedding the hypothesis, we can optionally apply an encoder.  If this is ``None``,
        we will use the ``premise_encoder`` for the encoding (doing nothing if ``premise_encoder``
        is also ``None``).
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 attend_feedforward: FeedForward,
                 similarity_function: SimilarityFunction,
                 compare_feedforward: FeedForward,
                 aggregate_feedforward: FeedForward,
                 initializer: InitializerApplicator,
                 premise_encoder: Optional[Seq2SeqEncoder] = None,
                 hypothesis_encoder: Optional[Seq2SeqEncoder] = None) -> None:
        """
        Wire up the sub-modules of the model, validate the output size, and initialize
        the parameters.

        The attend and compare feedforward networks are wrapped in ``TimeDistributed``; the
        similarity function is wrapped in ``MatrixAttention``.  When no ``hypothesis_encoder`` is
        given, the ``premise_encoder`` (possibly ``None``) is used for the hypothesis as well.

        Raises
        ------
        ConfigurationError
            If the output dimension of ``aggregate_feedforward`` differs from the size of the
            ``labels`` vocabulary namespace.
        """
        super(DecomposableAttention, self).__init__(vocab)

        self._text_field_embedder = text_field_embedder
        self._attend_feedforward = TimeDistributed(attend_feedforward)
        self._matrix_attention = MatrixAttention(similarity_function)
        self._compare_feedforward = TimeDistributed(compare_feedforward)
        self._aggregate_feedforward = aggregate_feedforward
        self._premise_encoder = premise_encoder
        self._hypothesis_encoder = hypothesis_encoder or premise_encoder

        self._num_labels = vocab.get_vocab_size(namespace="labels")
        if aggregate_feedforward.get_output_dim() != self._num_labels:
            raise ConfigurationError(
                "Final output dimension (%d) must equal num labels (%d)" %
                (aggregate_feedforward.get_output_dim(), self._num_labels))

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        # Apply the caller-supplied initializer last, once every sub-module has been registered,
        # so that it can see all of the model's parameters.
        initializer(self)

    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Embed (and optionally encode) both sentences, align them with soft attention, compare the
        aligned representations, sum over tokens, and classify the aggregated result.

        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise,
                                                   projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        # Each token is compared against the part of the other sentence it aligned to.
        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        # Zero out padded positions so that they do not contribute to the sum below.
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(
            -1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        # Normalise over the label dimension explicitly; calling softmax without ``dim`` relies
        # on a deprecated implicit choice.
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the metrics tracked by this model.

        Parameters
        ----------
        reset : ``bool``, optional (default = ``False``)
            Passed through to the accuracy metric's ``get_metric``.

        Returns
        -------
        A dictionary with the single key ``accuracy``.
        """
        return {
            'accuracy': self._accuracy.get_metric(reset),
        }

    def predict_entailment(self, premise: TextField,
                           hypothesis: TextField) -> Dict[str, torch.Tensor]:
        # NOTE(review): the descriptive text that follows is not enclosed in docstring quotes
        # in this listing, so it is not valid Python as shown; the delimiters and section
        # headers appear to have been lost -- confirm against the original file.
        Given a premise and a hypothesis sentence, predict the entailment relationship between
        them.  Note that in the paper, a null token was appended to each sentence, to allow for
        words to align to nothing in the other sentence.  If you've trained your model with a null
        token, you probably want to include it here, too.

        premise : ``TextField``
        hypothesis : ``TextField``

        A Dict containing:

        label_probs : torch.FloatTensor
            A tensor of shape ``(num_labels,)`` representing probabilities of the entailment label.
        # Wrap the two fields in a single instance using the argument names `forward` expects.
        instance = Instance({"premise": premise, "hypothesis": hypothesis})
        # NOTE(review): this call is cut off after its first argument; the remaining arguments
        # and the closing parenthesis are missing from this listing -- confirm.
        model_input = arrays_to_variables(instance.as_array_dict(),
        output_dict = self.forward(**model_input)

        # Remove batch dimension, as we only had one input.
        label_probs = output_dict["label_probs"].data.squeeze(0)
        # NOTE(review): the value returned is the result of `.numpy()`, while the return
        # annotation says `torch.Tensor` -- confirm which one callers rely on.
        return {'label_probs': label_probs.numpy()}

    def from_params(cls, vocab: Vocabulary,
                    params: Params) -> 'DecomposableAttention':
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(
            vocab, embedder_params)

        premise_encoder_params = params.pop("premise_encoder", None)
        if premise_encoder_params is not None:
            premise_encoder = Seq2SeqEncoder.from_params(
            premise_encoder = None

        hypothesis_encoder_params = params.pop("hypothesis_encoder", None)
        if hypothesis_encoder_params is not None:
            hypothesis_encoder = Seq2SeqEncoder.from_params(
            hypothesis_encoder = None

        attend_feedforward = FeedForward.from_params(
        similarity_function = SimilarityFunction.from_params(
        compare_feedforward = FeedForward.from_params(
        aggregate_feedforward = FeedForward.from_params(
        initializer = InitializerApplicator.from_params(
            params.pop("initializer", []))

        return cls(vocab=vocab,
Exemple #59
    def __init__(self,
                 vocab: Vocabulary,
                 token_embedder: TextFieldEmbedder,
                 entity_embedder: TextFieldEmbedder,
                 relation_embedder: TextFieldEmbedder,
                 knowledge_graph_path: str,
                 use_shortlist: bool,
                 hidden_size: int,
                 num_layers: int,
                 cutoff: int = 30,
                 tie_weights: bool = False,
                 dropout: float = 0.4,
                 dropouth: float = 0.3,
                 dropouti: float = 0.65,
                 dropoute: float = 0.1,
                 wdrop: float = 0.5,
                 alpha: float = 2.0,
                 beta: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super(KglmDisc, self).__init__(vocab)

        # We extract the `Embedding` layers from the `TokenEmbedders` to apply dropout later on.
        # pylint: disable=protected-access
        self._token_embedder = token_embedder._token_embedders['tokens']
        self._entity_embedder = entity_embedder._token_embedders['entity_ids']
        self._relation_embedder = relation_embedder._token_embedders['relations']
        self._recent_entities = RecentEntities(cutoff=cutoff)
        self._knowledge_graph_lookup = KnowledgeGraphLookup(knowledge_graph_path, vocab=vocab)
        self._use_shortlist = use_shortlist
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._cutoff = cutoff
        self._tie_weights = tie_weights

        # Dropout
        self._locked_dropout = LockedDropout()
        self._dropout = dropout
        self._dropouth = dropouth
        self._dropouti = dropouti
        self._dropoute = dropoute
        self._wdrop = wdrop

        # Regularization strength
        self._alpha = alpha
        self._beta = beta

        # RNN Encoders.
        entity_embedding_dim = entity_embedder.get_output_dim()
        token_embedding_dim = token_embedder.get_output_dim()
        self.entity_embedding_dim = entity_embedding_dim
        self.token_embedding_dim = token_embedding_dim

        rnns: List[torch.nn.Module] = []
        for i in range(num_layers):
            if i == 0:
                input_size = token_embedding_dim
                input_size = hidden_size
            if (i == num_layers - 1):
                output_size = token_embedding_dim + 2 * entity_embedding_dim
                output_size = hidden_size
            rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True))
        rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in rnns]
        self.rnns = torch.nn.ModuleList(rnns)

        # Various linear transformations.
        self._fc_mention_type = torch.nn.Linear(

        if not use_shortlist:
            self._fc_new_entity = torch.nn.Linear(

            if tie_weights:
                self._fc_new_entity.weight = self._entity_embedder.weight

        self._overlap_weight = torch.nn.Parameter(torch.tensor([1.]))

        self._state: Optional[Dict[str, Any]] = None

        # Metrics
        self._unk_index = vocab.get_token_index(DEFAULT_OOV_TOKEN)
        self._unk_penalty = math.log(vocab.get_vocab_size('tokens_unk'))
        self._avg_mention_type_loss = Average()
        self._avg_new_entity_loss = Average()
        self._avg_knowledge_graph_entity_loss = Average()
        self._new_mention_f1 =  F1Measure(positive_label=1)
        self._kg_mention_f1 = F1Measure(positive_label=2)
        self._new_entity_accuracy = CategoricalAccuracy()
        self._new_entity_accuracy20 = CategoricalAccuracy(top_k=20)
        self._parent_ppl = Ppl()
        self._relation_ppl = Ppl()
