Example #1
0
    def _get_initial_state_and_scores(
            self,
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            example_lisp_string: List[str] = None,
            add_world_to_initial_state: bool = False,
            checklist_states: List[ChecklistState] = None) -> Dict:
        """
        Does initial preparation and creates an initial state for both the semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text,
                                              num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
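        # A tiny hypothetical illustration of the trick described above:
        #   neighbor_indices     = [ 2, -1, -1]   # -1 marks padding
        #   torch.abs(...)       = [ 2,  1,  1]   # safe to index; padded slots are masked out below
        #   neighbor_indices + 1 = [ 3,  0,  0]   # 0 is what get_text_field_mask treats as padding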
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask(
            {
                'ignored': neighbor_indices + 1
            }, num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                              neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
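        # Hypothetical illustration: if entity 3 of batch element 0 has type 2 with
        # num_entities = 10 and num_types = 4, then entity_types[0, 3] is the one-hot
        # vector [0, 0, 1, 0] and, assuming the flattened index is
        # batch_index * num_entities + entity_index, entity_type_dict[3] == 2.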
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(
            embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(
            entity_type_embeddings + projected_neighbor_embeddings)

        # Compute entity and question word cosine similarity. Need to add a small value
        # to the table norm since there are padding values which cause a divide by 0.
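        # Without the 1e-13, a fully padded (all-zero) token row would be divided by its
        # zero norm and turn into NaNs; with it, that row simply stays a vector of zeros.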
        embedded_table = embedded_table / (
            embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
        embedded_question = embedded_question / (
            embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.
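            # Since both projections below are 1x1 Linear layers, the resulting linking score
            # is effectively w1 * max_sim(entity words, token) + w2 * max_sim(neighbor words, token)
            # plus the layers' biases.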

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = Variable(
            encoder_outputs.data.new(batch_size,
                                     self._encoder.get_output_dim()).fill_(0))

        initial_score = Variable(
            embedded_question.data.new(batch_size).fill_(0))

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(
            actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(
            linking_scores, world, actions)
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnState(final_encoder_output[i], memory_cell[i],
                         self._first_action_embedding,
                         self._first_attended_question, encoder_output_list,
                         question_mask_list))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i])
            for i in range(batch_size)
        ]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            output_action_embeddings=output_action_embeddings,
            action_biases=action_biases,
            action_indices=action_indices,
            possible_actions=actions,
            flattened_linking_scores=flattened_linking_scores,
            actions_to_entities=actions_to_entities,
            entity_types=entity_type_dict,
            world=initial_state_world,
            example_lisp_string=example_lisp_string,
            checklist_state=checklist_states,
            debug_info=None)
        return {
            "initial_state": initial_state,
            "linking_scores": linking_scores,
            "feature_scores": feature_scores,
            "similarity_scores": question_entity_similarity_max_score
        }
Example #2
0
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._max_decoding_steps = max_decoding_steps
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(tables_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._action_biases = Embedding(num_embeddings=num_actions,
                                        embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(
            torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal(self._first_action_embedding)
        torch.nn.init.normal(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(),
                               question_embedder.get_output_dim(),
                               "entity word average embedding dim",
                               "question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types,
                                            self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim,
                                                self._embedding_dim)

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        if self._use_neighbor_similarity_for_linking:
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)
        else:
            self._question_entity_params = None
            self._question_neighbor_params = None
Example #3
0
    def __init__(self, input_dim):
        super(SelfAttentiveSpanExtractor, self).__init__()
        self._input_dim = input_dim
        self._global_attention = TimeDistributed(torch.nn.Linear(input_dim, 1))
    def forward(
            self,
            metadata: List[Dict[str, Any]],
            event: Dict[str, torch.Tensor],
            information: Dict[str, torch.Tensor] = None,
            year: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:

        # shape: (batch_size, max_input_sequence_length)
        mask = get_text_field_mask(event)

        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        # This disables dropout in the embeddings, as with a pretrained BERT embedder.
        self.word_embeddings.eval()
        embeddings = self.word_embeddings(event)

        assert mask.shape[0] == embeddings.shape[0]
        # BERT doesn't enforce this (because of wordpiece tokenization). The mask is word-based, and the embeddings are wordpiece-based.
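        # (A single word can map to several wordpieces, so embeddings.shape[1], the wordpiece
        # count, can exceed mask.shape[1], the word count.)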
        # But we then set the encoder to be a BERTPooler and are unaffected by this.
        #assert mask.shape[1] == embeddings.shape[1]
        assert embeddings.shape[2] == self.word_embeddings.get_output_dim()

        # shape: (batch_size, encoder_output_dim)
        if torch.sum(mask).int().item() == 0:
            # There is no event text associated with any of the datapoints, so we return a vector of zeroes
            batch_size = mask.shape[0]
            encoder_out = embeddings.new_zeros(
                (batch_size, self.encoder.get_output_dim()))
        else:
            encoder_out = self.encoder(embeddings, mask)

        assert mask.shape[0] == encoder_out.shape[0]
        assert encoder_out.shape[1] == self.encoder.get_output_dim()

        if self.use_ci:
            # Since the CI is a ListField of TextFields, we need to take care to pass num_wrapping_dim
            # (our tensors have an extra dimension)
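            # Passing num_wrapping_dims=1 makes the embedder wrap itself in TimeDistributed
            # once, so the extra per-datapoint sentence dimension is handled transparently.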

            # shape: (batch_size, max_ci_sent_per_inp_in_batch, max_ci_sequence_length, encoder_input_dim)
            embeddings_ci = self.word_embeddings(information,
                                                 num_wrapping_dims=1)

            assert embeddings_ci.shape[0] == mask.shape[0]
            assert embeddings_ci.shape[
                3] == self.word_embeddings.get_output_dim()

            # shape: (batch_size, max_ci_sent_per_inp_in_batch, max_ci_sequence_length)
            mask_ci = get_text_field_mask(information, num_wrapping_dims=1)

            assert mask_ci.shape[0] == embeddings_ci.shape[0]
            assert mask_ci.shape[1] == embeddings_ci.shape[1]
            # BERT doesn't enforce this (because of wordpiece tokenization). The mask is word-based, and the embeddings are wordpiece-based.
            # But we then set the encoder to be a BERTPooler and are unaffected by this.
            #assert mask_ci.shape[2] == embeddings_ci.shape[2]

            # shape: (batch_size, encoder_output_dim)
            #
            # Here we need to be careful, as AllenNLP crashes if all datapoints in the batch are devoid of CI.
            # We handle this case by returning a vector of zeroes for all datapoints. This is compatible with how
            # AllenNLP handles the encoding of sequences of length zero (i.e. cases where there is no CI to encode,
            # because of padding or because there just isn't any CI associated with the datapoint).
            if torch.sum(mask_ci).int().item() == 0:
                # There is no CI associated with any of the datapoints. We return a vector of zeroes
                batch_size = mask_ci.shape[0]
                encoder_out_ci_condensed = embeddings_ci.new_zeros(
                    (batch_size, self.encoder.get_output_dim()))

                assert encoder_out_ci_condensed.shape[0] == mask_ci.shape[0]
                assert encoder_out_ci_condensed.shape[
                    1] == self.encoder.get_output_dim()
                assert len(encoder_out_ci_condensed.shape) == 2
            else:
                # shape: (batch_size, max_ci_sent_per_inp_in_batch, encoder_output_dim)
                #
                # Given a module and some tensor, TimeDistributed transforms the N-dimensional tensor into an (N-1)-dimensional tensor,
                # by merging the first 2 dimensions, then passes this tensor to the module, and then splits the merged dimension.
                # Hence we can apply our listfield of textfields to our encoder that only takes text fields.
                # TODO: Right now we are using the exact same encoder as for the event. Maybe we should use a different one?
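                # Rough shape walk-through (hypothetical sizes): embeddings_ci of shape
                # (B, S, T, D) is reshaped to (B * S, T, D), run through the encoder to get
                # (B * S, H), and then split back into (B, S, H).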
                encoder_out_ci = TimeDistributed(self.encoder)(embeddings_ci,
                                                               mask_ci)

                assert encoder_out_ci.shape[0] == mask_ci.shape[0]
                assert encoder_out_ci.shape[1] == mask_ci.shape[1]
                assert encoder_out_ci.shape[2] == self.encoder.get_output_dim()

                # We have the encoding for each CI sentence. Now we reduce it to a single vector.
                # We do so by applying an attention mechanism, and then computing a weighted average
                # of the CI encoder results according to the attention weights.
                #
                # If there is no CI associated with a datapoint, we just return a vector of zeros for that CI embedding.
                # We do this by exploiting the fact that the CI encoding when there isn't any CI is just a vector of zeroes.
                # So for these cases we programmatically set the attention of the first item to 1, thus giving the zeroes
                # vector complete attention, which after the weighted average means that the weighted vector is still
                # a vector of zeroes.

                # shape: (batch_size, max_ci_sent_per_inp_in_batch)
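                # With right-padded sequences, the first mask entry of a CI sentence is 1
                # exactly when that sentence has at least one real token, so the slice below
                # doubles as a per-sentence validity mask for the attention.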
                encoder_out_ci_valid = mask_ci[:, :, 0]

                assert encoder_out_ci_valid.shape[0] == mask_ci.shape[0]
                assert encoder_out_ci_valid.shape[1] == mask_ci.shape[1]
                assert len(encoder_out_ci_valid.shape) == 2

                # shape: (batch_size, 1, max_ci_sent_per_inp_in_batch)
                # WARNING: When max_ci_sent_inp_in_batch == 1, and attention is linear, this crashes. Probably because
                #          of broadcasting in the wrong direction when calculating attention. This might be an AllenNLP bug.
                #          other forms of attention work fine.
                encoder_out_ci_attention_weights = self.ci_attention(
                    encoder_out, encoder_out_ci,
                    encoder_out_ci_valid).unsqueeze(dim=1)

                assert encoder_out_ci_attention_weights.shape[
                    0] == mask_ci.shape[0]
                assert encoder_out_ci_attention_weights.shape[1] == 1
                assert encoder_out_ci_attention_weights.shape[
                    2] == mask_ci.shape[1]
                assert len(encoder_out_ci_attention_weights.shape) == 3

                # Now we calculate the weighted encoding.
                # shape: (batch_size, encoder_output_dim)
                encoder_out_ci_condensed = encoder_out_ci_attention_weights.bmm(
                    encoder_out_ci).squeeze(dim=1)
                assert encoder_out_ci_condensed.shape[0] == mask_ci.shape[0]
                assert encoder_out_ci_condensed.shape[
                    1] == self.encoder.get_output_dim()
                assert len(encoder_out_ci_condensed.shape) == 2

            # shape: (batch_size, encoder_output_dim * 2)
            encoder_out_ci_event_combined = torch.cat(
                [encoder_out, encoder_out_ci_condensed], dim=1)

            assert encoder_out_ci_event_combined.shape[0] == mask_ci.shape[0]
            assert encoder_out_ci_event_combined.shape[
                1] == self.encoder.get_output_dim() * 2

            # shape: if "reg" (batch_size,); if "cls" (batch_size, num_classes)
            predicted_output = self.vec2year(
                encoder_out_ci_event_combined).squeeze(dim=1)
        else:
            # shape: if "reg" (batch_size,); if "cls" (batch_size, num_classes)
            predicted_output = self.vec2year(encoder_out).squeeze(dim=1)

        # Compute the predictions
        if self.predictor_type == "reg":
            # We only need to calculate the predicted years in regression
            # shape: (batch_size,)
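            # Hypothetical example: with normalize_outputs_mean = 1950 and
            # normalize_outputs_std = 50, a raw prediction of 0.4 maps to the year 1970.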
            predicted_year = predicted_output * self.normalize_outputs_std + self.normalize_outputs_mean
            predicted_year = predicted_year.float()
            assert predicted_year.shape[0] == mask.shape[0]

            output = {"year": predicted_year}
        elif self.predictor_type == "cls":
            # For classification we compute the logits, and from that the predicted years
            # shape: (batch_size, num_classes)
            predicted_logit = predicted_output

            assert predicted_logit.shape[0] == mask.shape[0]
            assert predicted_logit.shape[
                1] == self.year_max - self.year_min + 1

            # shape: (batch_size,)
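            # Hypothetical example: with year_min = 1900, a predicted class index of 37
            # corresponds to the year 1937.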
            predicted_year = predicted_logit.argmax(dim=1) + self.year_min
            predicted_year = predicted_year.float()
            assert predicted_year.shape[0] == mask.shape[0]

            output = {"year": predicted_year, "logit": predicted_logit}

        if year is not None:
            # shape: (batch_size,)
            year = year * self.normalize_outputs_std + self.normalize_outputs_mean
            year = year.float()
            assert year.shape[0] == mask.shape[0]

            # Compute loss
            if self.predictor_type == "reg":
                # MSE loss calculation
                output["loss"] = F.mse_loss(predicted_year, year)
            elif self.predictor_type == "cls":
                # Cross-entropy loss calculation
                year_index = year - self.year_min
                year_index = year_index.long()
                output["loss"] = F.cross_entropy(predicted_logit, year_index)

            # Compute metrics
            self.mae(predicted_year, year)
            self.kendall_tau(predicted_year, year)
            self.exact_match(predicted_year, year)
            self.under_20y(predicted_year, year)
            self.under_50y(predicted_year, year)

        return output
Example #5
0
    def __init__(
        self,
        vocab: Vocabulary,
        question_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        decoder_beam_search: BeamSearch,
        max_decoding_steps: int,
        attention: Attention,
        mixture_feedforward: FeedForward = None,
        add_action_bias: bool = True,
        dropout: float = 0.0,
        num_linking_features: int = 0,
        num_entity_bits: int = 0,
        entity_bits_output: bool = True,
        use_entities: bool = False,
        denotation_only: bool = False,
        # Deprecated parameter to load older models
        entity_encoder: Seq2VecEncoder = None,
        entity_similarity_mode: str = "dot_product",
        rule_namespace: str = "rule_labels",
    ) -> None:
        super(QuarelSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._entity_similarity_layer = None
        self._entity_similarity_mode = entity_similarity_mode
        if self._entity_similarity_mode == "weighted_dot_product":
            self._entity_similarity_layer = TimeDistributed(
                torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Center initial values around unweighted dot product
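            # With the default (small) initialization, adding 1 below makes each weight start
            # near 1, so the layer initially behaves like a plain, unweighted dot product.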
            self._entity_similarity_layer._module.weight.data += 1
        elif self._entity_similarity_mode == "dot_product":
            pass
        else:
            raise ValueError("Invalid entity_similarity_mode: {}".format(
                self._entity_similarity_mode))

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._encoder_output_dim = self._encoder.get_output_dim()
        if entity_bits_output:
            self._encoder_output_dim += num_entity_bits

        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if self._denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            self._denotation_classifier = torch.nn.Linear(
                self._encoder_output_dim, self._num_denotation_cats)
            # The rest of init is not needed in denotation-only mode, where no decoding to actions is needed
            return

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = num_actions
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=action_embedding_dim)
        # We are tying the action embeddings used for input and output
        # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = self._action_embedder  # tied weights
        self._add_action_bias = add_action_bias
        if self._add_action_bias:
            self._action_biases = Embedding(num_embeddings=num_actions,
                                            embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(
            torch.FloatTensor(self._encoder_output_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        self._decoder_step = LinkingTransitionFunction(
            encoder_output_dim=self._encoder_output_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            add_action_bias=self._add_action_bias,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout,
        )
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRuleArray]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default=None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """

        table_text = table['text']

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word cosine similarity. Need to add a small value
        # to the table norm since there are padding values which cause a divide by 0.
        embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
        embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = question_entity_similarity_max_score + feature_scores
        else:
            # The linking score is computed as a linear projection of two terms. The first is the maximum
            # similarity score over the entity's words and the question token. The second is the maximum
            # similarity over the words in the entity's neighbors and the question token.
            #   The second term, projected_question_neighbor_similarity, is useful when
            # a column needs to be selected. For example, the question token might have no similarity
            # with the column name, but is similar with the cells in the column.
            #   Note that projected_question_neighbor_similarity is intended to capture the same information
            # as the related_column feature.
            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = Variable(encoder_outputs.data.new(batch_size, self._encoder.get_output_dim()).fill_(0))

        initial_score = Variable(embedded_question.data.new(batch_size).fill_(0))

        action_embeddings, action_indices = self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               debug_info=None)
        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['linking_scores'] = linking_scores
            if self._linking_params is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = 0
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    if example_lisp_string:
                        self._denotation_accuracy(logical_form, example_lisp_string[i])
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                    outputs['entities'].append(world[i].table_graph.entities)
                else:
                    outputs['logical_form'].append('')
                    self._has_logical_form(0.0)
                    if example_lisp_string:
                        self._denotation_accuracy(None, example_lisp_string[i])
            return outputs
    def forward(
        self, text_field_input: Dict[str, torch.Tensor], num_wrapping_dims: int = 0, **kwargs
    ) -> torch.Tensor:
        embedder_keys = self._token_embedders.keys()
        input_keys = text_field_input.keys()

        # Check for unmatched keys
        if not self._allow_unmatched_keys:
            if embedder_keys < input_keys:
                # token embedder keys are a strict subset of text field input keys.
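                # dict_keys objects support set comparisons, so `embedder_keys < input_keys`
                # is True exactly when the embedder keys are a proper subset of the input keys.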
                message = (
                    f"Your text field is generating more keys ({list(input_keys)}) "
                    f"than you have token embedders ({list(embedder_keys)}. "
                    f"If you are using a token embedder that requires multiple keys "
                    f"(for example, the OpenAI Transformer embedder or the BERT embedder) "
                    f"you need to add allow_unmatched_keys = True "
                    f"(and likely an embedder_to_indexer_map) to your "
                    f"BasicTextFieldEmbedder configuration. "
                    f"Otherwise, you should check that there is a 1:1 embedding "
                    f"between your token indexers and token embedders."
                )
                raise ConfigurationError(message)

            elif self._token_embedders.keys() != text_field_input.keys():
                # some other mismatch
                message = "Mismatched token keys: %s and %s" % (
                    str(self._token_embedders.keys()),
                    str(text_field_input.keys()),
                )
                raise ConfigurationError(message)

        embedded_representations = []
        keys = sorted(embedder_keys)
        for key in keys:
            # Note: need to use getattr here so that the pytorch voodoo
            # with submodules works with multiple GPUs.
            embedder = getattr(self, "token_embedder_{}".format(key))
            forward_params = inspect.signature(embedder.forward).parameters
            forward_params_values = {}
            for param in forward_params.keys():
                if param in kwargs:
                    forward_params_values[param] = kwargs[param]

            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)
            # If we pre-specified a mapping explicitly, use that.
            # make mypy happy
            tensors: Union[List[Any], Dict[str, Any]] = None
            if self._embedder_to_indexer_map is not None:
                indexer_map = self._embedder_to_indexer_map[key]
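                # Hypothetical examples of the two supported forms (indexer names made up):
                #   list form: ["wordpieces", "wordpiece-offsets"] passes the matching tensors
                #   positionally; dict form: {"input_ids": "wordpieces", "offsets": "wordpiece-offsets"}
                #   passes them as keyword arguments named after the embedder's parameters.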
                if isinstance(indexer_map, list):
                    # If `indexer_key` is None, we map it to `None`.
                    tensors = [
                        (text_field_input[indexer_key] if indexer_key is not None else None)
                        for indexer_key in indexer_map
                    ]
                    token_vectors = embedder(*tensors, **forward_params_values)
                elif isinstance(indexer_map, dict):
                    tensors = {
                        name: text_field_input[argument] for name, argument in indexer_map.items()
                    }
                    token_vectors = embedder(**tensors, **forward_params_values)
                else:
                    raise NotImplementedError
            else:
                # otherwise, we assume the mapping between indexers and embedders
                # is bijective and just use the key directly.
                tensors = [text_field_input[key]]
                token_vectors = embedder(*tensors, **forward_params_values)
            embedded_representations.append(token_vectors)
        return torch.cat(embedded_representations, dim=-1)
Example #8
0
    def __init__(
        self,
        vocab: Vocabulary,
        use_glove: bool,
        use_bert: bool,
        bert_trainable: bool,
        bert_name: str,
        mention_embedder: TextFieldEmbedder,
        dialog_context: Seq2SeqEncoder,
        fact_ranker: FactRanker,
        dropout_prob: float,
        sender_emb_size: int,
        act_emb_size: int,
        fact_loss_weight: float,
        fact_pos_weight: float,
        utter_embedder: TextFieldEmbedder = None,
        utter_context: Seq2VecEncoder = None,
        disable_known_entities: bool = False,
        disable_dialog_acts: bool = False,
        disable_likes: bool = False,
        disable_facts: bool = False,
    ):
        super().__init__(vocab)
        self._disable_known_entities = disable_known_entities
        self._disable_dialog_acts = disable_dialog_acts
        self._clamp_dialog_acts = Clamp(should_clamp=disable_dialog_acts)
        self._disable_likes = disable_likes
        self._clamp_likes = Clamp(should_clamp=disable_likes)
        self._disable_facts = disable_facts
        self._clamp_facts = Clamp(should_clamp=disable_facts)

        self._fact_loss_weight = fact_loss_weight
        self._fact_pos_weight = fact_pos_weight

        self._sender_emb_size = sender_emb_size
        self._sender_emb = nn.Embedding(2, sender_emb_size)
        # Easier to use a matrix as embeddings, given the input format
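        # A bias-free Linear applied to a multi-hot dialog-act vector sums the weight columns
        # of the active acts, which is effectively a (summed) embedding lookup for this format.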
        self._act_embedder = nn.Linear(vocab.get_vocab_size(DIALOG_ACT_LABELS),
                                       act_emb_size,
                                       bias=False)
        self._mention_embedder = mention_embedder

        if int(use_glove) + int(use_bert) != 1:
            raise ValueError("Cannot use bert and glove together")

        self._use_glove = use_glove
        self._use_bert = use_bert
        self._bert_trainable = bert_trainable
        self._bert_name = bert_name
        self._utter_embedder = utter_embedder
        self._utter_context = utter_context
        # Bert encoder is embedder + context
        if use_bert:
            # Not trainable for now
            self._bert_encoder = BertEncoder(self._bert_name,
                                             requires_grad=bert_trainable)
            self._dist_utter_context = None
            self._utter_dim = self._bert_encoder.get_output_dim()
        else:
            self._bert_encoder = None
            self._dist_utter_context = TimeDistributed(self._utter_context)
            self._utter_dim = self._utter_context.get_output_dim()
        self._dialog_context = dialog_context
        self._fact_ranker = fact_ranker
        # Easier to code as cross entropy with two classes
        # Likes are per message, for only assistant messages
        self._like_classifier = nn.Linear(
            self._dialog_context.get_output_dim(), 2)
        self._like_accuracy = CategoricalAccuracy()
        self._like_loss_metric = Average()

        # Transform the word_dim entity reps to hidden_dim
        self._focus_net = nn.Sequential(
            nn.Linear(
                self._mention_embedder.get_output_dim(),
                self._dialog_context.get_output_dim(),
            ),
            GeLU(),
        )
        self._known_net = nn.Sequential(
            nn.Linear(
                self._mention_embedder.get_output_dim(),
                self._dialog_context.get_output_dim(),
            ),
            GeLU(),
            Clamp(should_clamp=disable_known_entities),
        )
        # If we don't use known entities, disable gradients for this network's parameters.
        # (Setting requires_grad on the Module itself would not affect its parameters.)
        for param in self._known_net.parameters():
            param.requires_grad = not disable_known_entities

        # Dialog acts are per message, for all messages
        # This network predicts the dialog act of the current message
        # for both student and teacher
        self._da_classifier = nn.Sequential(
            nn.Linear(
                self._utter_dim + self._dialog_context.get_output_dim(),
                self._dialog_context.get_output_dim(),
            ),
            GeLU(),
            nn.Linear(
                self._dialog_context.get_output_dim(),
                vocab.get_vocab_size(DIALOG_ACT_LABELS),
            ),
        )
        self._da_bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        self._da_f1_metric = MultilabelMicroF1()
        self._da_loss_metric = Average()

        # This network predicts what the next action should be
        # It predicts for user and assistant since there isn't a real
        # reason to restrict that
        self._policy_classifier = nn.Sequential(
            nn.Linear(
                self._dialog_context.get_output_dim(),
                self._dialog_context.get_output_dim(),
            ),
            GeLU(),
            nn.Linear(
                self._dialog_context.get_output_dim(),
                vocab.get_vocab_size(DIALOG_ACT_LABELS),
            ),
        )
        self._policy_bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        self._policy_f1_metric = MultilabelMicroF1()
        self._policy_loss_metric = Average()

        self._fact_mrr = MeanReciprocalRank()
        self._fact_loss_metric = Average()
        self._dropout_prob = dropout_prob
        self._dropout = nn.Dropout(dropout_prob)
        # Fact use is much less prevalent, about 9 times less so, so factor that in
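        # pos_weight scales the loss contribution of positive fact labels; e.g. a (hypothetical)
        # fact_pos_weight of 9.0 makes each positive label count roughly nine times as much as a
        # negative one in the BCE loss below.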
        self._fact_bce_loss = torch.nn.BCEWithLogitsLoss(
            reduction="none", pos_weight=torch.Tensor([self._fact_pos_weight]))
Example #9
0
    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self._input_dim = input_dim
        self._global_attention = TimeDistributed(torch.nn.Linear(input_dim, 1))
Example #10
0
    def forward(
        self,
        text_field_input: Dict[str, torch.Tensor],
        classifier_name: str = "@pretrain@",
        num_wrapping_dims: int = 0,
    ) -> torch.Tensor:
        # if self._token_embedders.keys() != text_field_input.keys():
        #     message = "Mismatched token keys: %s and %s" % (
        #         str(self._token_embedders.keys()),
        #         str(text_field_input.keys()),
        #     )
        #     raise ConfigurationError(message)
        embedded_representations = []
        keys = sorted(text_field_input.keys())
        for key in keys:
            # We handle count2vec indices as a special case below.
            if key.startswith("count2vec"): 
                continue

            tensor = text_field_input[key]
            # Note: need to use getattr here so that the pytorch voodoo
            # with submodules works with multiple GPUs.
            embedder = getattr(self, "token_embedder_{}".format(key))
            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)
            token_vectors = embedder(tensor)

            # Changed vs original:
            # If we want separate scalar mixes per task, figure out which representation to
            # use, since the embedder creates a representation for _all_ sets of scalars. This
            # can be optimized with more wrapper classes, but we compute all of them for now.
            # In the shared ELMo scalar-weights version, all tasks use the @pretrain@ embeddings.
            # There must be at least as many ELMo representations as the highest index in
            # self.task_map, otherwise indexing will fail.
            if key == "elmo" and not self.elmo_chars_only:
                if self.sep_embs_for_skip:
                    token_vectors = token_vectors["elmo_representations"][
                        self.task_map[classifier_name]
                    ]
                else:
                    token_vectors = token_vectors["elmo_representations"][
                        self.task_map["@pretrain@"]
                    ]

            # optional projection step that we are ignoring.
            embedded_representations.append(token_vectors)

        if "count2vec_indices" in keys and "count2vec_values" in keys:
            count2vec_indices = text_field_input["count2vec_indices"]
            count2vec_values = text_field_input["count2vec_values"]

            # Note: need to use getattr here so that the pytorch voodoo
            # with submodules works with multiple GPUs.
            embedder = getattr(self, "token_embedder_{}".format("count2vec"))
            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)
            token_vectors = embedder(count2vec_indices, count2vec_values)
            # optional projection step that we are ignoring.
            embedded_representations.append(token_vectors)

        return torch.cat(embedded_representations, dim=-1)
Example #11
0
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 output_logit: FeedForward,
                 final_feedforward: FeedForward,
                 coverage_loss: CoverageLoss = None,
                 contextualize_pair_comparators: bool = False,
                 pair_context_encoder: Seq2SeqEncoder = None,
                 pair_feedforward: FeedForward = None,
                 optimize_coverage_for: List = ["entailment", "neutral"],
                 dropout: float = 0.5,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)
        self._label2idx = self.vocab.get_token_to_index_vocabulary('labels')

        self._text_field_embedder = text_field_embedder

        self._entailment_comparator_layer_1 = EsimComparatorLayer1(
            encoder, dropout)
        self._entailment_comparator_layer_2 = EsimComparatorLayer2(
            similarity_function)

        self._td_entailment_comparator_layer_1 = TimeDistributed(
            self._entailment_comparator_layer_1)
        self._td_entailment_comparator_layer_2 = TimeDistributed(
            self._entailment_comparator_layer_2)

        self._entailment_comparator_layer_3plus_local = EsimComparatorLayer3Plus(
            projection_feedforward, inference_encoder, output_feedforward,
            dropout)
        self._td_entailment_comparator_layer_3plus_local = TimeDistributed(
            self._entailment_comparator_layer_3plus_local)

        self._entailment_comparator_layer_3plus_global = copy.deepcopy(
            self._entailment_comparator_layer_3plus_local)

        self._contextualize_pair_comparators = contextualize_pair_comparators

        if not self._contextualize_pair_comparators:
            self._output_logit = output_logit
            self._td_output_logit = TimeDistributed(self._output_logit)

        self._final_feedforward = final_feedforward
        self._td_final_feedforward = TimeDistributed(final_feedforward)

        linear = torch.nn.Linear(
            2 * self._entailment_comparator_layer_3plus_local.get_output_dim(),
            self._final_feedforward.get_input_dim())
        self._local_global_projection = torch.nn.Sequential(
            linear, torch.nn.ReLU())

        if self._contextualize_pair_comparators:
            self._pair_context_encoder = pair_context_encoder
            self._td_pair_feedforward = TimeDistributed(pair_feedforward)

        self._coverage_loss = coverage_loss