Example #1
 def _get_pred_mask(self,
                    output_symbols: torch.LongTensor) -> torch.BoolTensor:
     batch_size, max_pred_len = output_symbols.size()
     pred_mask = output_symbols.new_ones((batch_size, max_pred_len),
                                         dtype=torch.bool)
     for i in range(1, max_pred_len):
         pred_mask[:, i] = pred_mask[:, i - 1] & (output_symbols[:, i - 1]
                                                  != self.eos_id)
     return pred_mask
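A standalone sketch of the same recurrence, with the method rewritten as a free function and an illustrative eos_id (the class presumably stores it as self.eos_id): every position stays True up to and including the first EOS, and everything after it turns False.

    import torch

    def get_pred_mask(output_symbols: torch.LongTensor, eos_id: int) -> torch.BoolTensor:
        # Carry validity forward: a position is valid iff the previous one
        # was valid and was not the EOS symbol.
        batch_size, max_pred_len = output_symbols.size()
        pred_mask = output_symbols.new_ones((batch_size, max_pred_len), dtype=torch.bool)
        for i in range(1, max_pred_len):
            pred_mask[:, i] = pred_mask[:, i - 1] & (output_symbols[:, i - 1] != eos_id)
        return pred_mask

    symbols = torch.tensor([[5, 7, 2, 9, 9],
                            [5, 2, 2, 9, 9]])
    print(get_pred_mask(symbols, eos_id=2))
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True, False, False, False]])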
Example #2
    def __call__(self, tokens: torch.LongTensor, prefix_mask: torch.LongTensor):
        padding_mask = tokens.new_ones(tokens.size(), dtype=torch.bool)
        for pad in self.excludes:
            padding_mask &= (tokens != pad)
        padding_mask &= prefix_mask.bool()  # only corrupt prefix tokens; the rest are never attended to
        # Create a uniformly random mask selecting either the original words or OOV tokens
        dropout_mask = (tokens.new_empty(tokens.size(), dtype=torch.float).uniform_() < self.mask_prob)
        oov_mask = dropout_mask & padding_mask

        oov_fill = tokens.new_empty(tokens.size(), dtype=torch.long).fill_(self.oov)

        result = torch.where(oov_mask, oov_fill, tokens)
        return result, oov_mask
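A self-contained rewrite of __call__ as a plain function, for experimentation; excludes, oov, and mask_prob stand in for the attributes the class is assumed to carry.

    import torch

    def word_dropout(tokens, prefix_mask, excludes, oov, mask_prob):
        # Corrupt only non-excluded prefix positions, each with
        # probability mask_prob; everything else passes through.
        padding_mask = tokens.new_ones(tokens.size(), dtype=torch.bool)
        for pad in excludes:
            padding_mask &= (tokens != pad)
        padding_mask &= prefix_mask.bool()
        dropout_mask = tokens.new_empty(tokens.size(), dtype=torch.float).uniform_() < mask_prob
        oov_mask = dropout_mask & padding_mask
        return torch.where(oov_mask, torch.full_like(tokens, oov), tokens), oov_mask

    tokens = torch.tensor([[4, 5, 6, 7, 0]])
    prefix = torch.tensor([[1, 1, 1, 0, 0]])  # only the first three positions may be corrupted
    corrupted, oov_mask = word_dropout(tokens, prefix, excludes=[0], oov=1, mask_prob=0.5)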
Example #3
    def contains(self, batch: torch.LongTensor) -> torch.BoolTensor:
        """
        Check whether a triple is contained.

        :param batch: shape (batch_size, 3)
            The batch of triples.

        :return: shape: (batch_size,)
            The result. A False entry guarantees the triple is not among the indexed triples; a True entry
            may be a false positive.
        """
        result = batch.new_ones(batch.shape[0], dtype=torch.bool)
        for i in self.probe(batch):
            result &= self.bit_array[i]
        return result
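The method is one half of a Bloom-filter membership test: a triple counts as contained only if every probed bit is set. A toy end-to-end sketch, with made-up hash probes standing in for the class's actual probe:

    import torch

    bit_array = torch.zeros(97, dtype=torch.bool)

    def probe(batch: torch.LongTensor):
        # Two illustrative hash functions over (head, relation, tail);
        # each yields one index tensor of shape (batch_size,).
        for seed in (31, 57):
            yield (batch * torch.tensor([1, seed, seed * seed])).sum(-1) % 97

    def insert(batch):
        for i in probe(batch):
            bit_array[i] = True

    def contains(batch):
        result = batch.new_ones(batch.shape[0], dtype=torch.bool)
        for i in probe(batch):
            result &= bit_array[i]
        return result

    triples = torch.tensor([[0, 1, 2], [3, 4, 5]])
    insert(triples)
    print(contains(triples))                    # tensor([True, True])
    print(contains(torch.tensor([[9, 9, 9]])))  # tensor([False]) barring hash collisions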
Example #4
    def token_dropout(tokens: torch.LongTensor,
                      oov_token: int,
                      exclude_tokens: List[int],
                      p: float = 0.2,
                      training: bool = True) -> torch.LongTensor:
        """During training, randomly replaces some of the non-padding tokens to a mask token with probability ``p``
        
        Adopted from https://github.com/Hyperparticle/udify

        Args:
          tokens: The current batch of padded sentences with word ids
          oov_token: The mask token
          exclude_tokens: The tokens for padding the input batch
          p: The probability a word gets mapped to the unknown token
          training: Applies the dropout if set to ``True``

        Returns:
          A copy of the input batch with token dropout applied

        """
        if training and p > 0:
            # This creates a mask that only considers unpadded tokens for mapping to oov
            padding_mask = tokens.new_ones(tokens.size(), dtype=torch.bool)
            for pad in exclude_tokens:
                padding_mask &= (tokens != pad)

            # Create a uniformly random mask selecting either the original words or OOV tokens
            dropout_mask = (tokens.new_empty(tokens.size(),
                                             dtype=torch.float).uniform_() < p)
            oov_mask = dropout_mask & padding_mask

            oov_fill = tokens.new_empty(tokens.size(),
                                        dtype=torch.long).fill_(oov_token)

            result = torch.where(oov_mask, oov_fill, tokens)

            return result
        else:
            return tokens
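A short usage sketch, assuming token_dropout is reachable as a plain function (it takes no self); the token ids are arbitrary placeholders.

    import torch

    torch.manual_seed(0)
    batch = torch.tensor([[101, 7592, 2088, 102, 0, 0]])  # e.g. [CLS] w1 w2 [SEP] PAD PAD
    dropped = token_dropout(batch,
                            oov_token=100,
                            exclude_tokens=[0, 101, 102],
                            p=0.5)
    # Excluded ids survive untouched; each remaining token is replaced by
    # oov_token with probability 0.5. With training=False the input batch
    # is returned unchanged.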
Example #5
File: editor.py  Project: isomap/factedit
    def _action_to_token(self, action_tokens: torch.LongTensor,
                         draft_tokens: torch.LongTensor) -> torch.LongTensor:
        predicted_pointer = action_tokens.new_zeros((draft_tokens.size(0), 1))
        draft_pointer = draft_tokens.new_ones((draft_tokens.size(0), 1))

        predicted_tokens = action_tokens.new_full(action_tokens.size(),
                                                  self.END)

        for act_step in action_tokens.t():
            # Action types: KEEP, DROP, COPY, ADD (anything else)
            keep_mask = act_step == self.KEEP
            drop_mask = act_step == self.DROP
            add_mask = ~(keep_mask | drop_mask)

            predicted_tokens.scatter_(1, predicted_pointer,
                                      draft_tokens.gather(1, draft_pointer))
            predicted_tokens[add_mask] = predicted_tokens[add_mask].scatter(
                1, predicted_pointer[add_mask],
                act_step[add_mask].unsqueeze(1))

            draft_pointer[keep_mask | drop_mask] += 1
            predicted_pointer[~drop_mask] += 1
        return predicted_tokens
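The loop above rests on two primitives: gather reads one draft token per row at draft_pointer, and scatter_ writes one token per row at predicted_pointer. A minimal sketch of that per-row pointer mechanic:

    import torch

    draft = torch.tensor([[10, 11, 12],
                          [20, 21, 22]])
    pointer = torch.tensor([[1], [2]])  # one column index per row
    picked = draft.gather(1, pointer)   # tensor([[11], [22]])
    out = torch.zeros_like(draft)
    out.scatter_(1, pointer, picked)    # writes each row's picked token back
    # out: tensor([[ 0, 11,  0],
    #              [ 0,  0, 22]])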
Example #6
    def _parse(
        self,
        encoded_text: torch.Tensor,
        mask: torch.LongTensor,
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask,
            )
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask,
            )

        return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll
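The minus_mask step is the standard additive masking trick: padding rows and columns of attended_arcs receive a large negative constant so they carry essentially zero probability under any later softmax. A minimal sketch:

    import torch

    scores = torch.zeros(1, 3, 3)              # (batch, seq_len, seq_len) arc scores
    float_mask = torch.tensor([[1., 1., 0.]])  # last position is padding
    minus_mask = (1 - float_mask) * -1e8
    scores = scores + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)
    print(scores.softmax(-1)[0, 0])
    # tensor([0.5000, 0.5000, 0.0000]) -- real tokens renormalize over the
    # two real positions; the padded position gets no mass.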
Example #7
    def forward(self,  # type: ignore
                # words: Dict[str, torch.LongTensor],
                encoded_text: torch.FloatTensor,
                mask: torch.LongTensor,
                pos_logits: torch.FloatTensor = None,  # predicted POS tag logits
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, _, _ = encoded_text.size()

        pos_tags = None
        if pos_logits is not None and self.pos_tag_embedding is not None:
            # Embed the predicted POS tags and concatenate the embeddings to the input
            num_pos_classes = pos_logits.size(-1)
            pos_logits = pos_logits.view(-1, num_pos_classes)
            _, pos_tags = pos_logits.max(-1)

            pos_embed_size = self.pos_tag_embedding.get_output_dim()
            embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags))
            embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size)
            encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1)

        encoded_text = self.encoder(encoded_text, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            # "pos": [meta["pos"] for meta in metadata]
        }

        return output_dict
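The POS branch above turns per-token logits into hard tag ids, embeds them, and widens the encoder input. A shape-only sketch with placeholder dimensions:

    import torch
    import torch.nn as nn

    batch_size, seq_len, num_pos, pos_dim = 2, 5, 17, 8
    pos_logits = torch.randn(batch_size, seq_len, num_pos)
    pos_tag_embedding = nn.Embedding(num_pos, pos_dim)

    _, pos_tags = pos_logits.view(-1, num_pos).max(-1)      # hard argmax tags, shape (10,)
    embedded = pos_tag_embedding(pos_tags).view(batch_size, -1, pos_dim)
    encoded_text = torch.randn(batch_size, seq_len, 32)
    encoded_text = torch.cat([encoded_text, embedded], -1)  # (2, 5, 40)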
Example #8
    def _parse(self,
               embedded_text_input: torch.Tensor,
               mask: torch.LongTensor,
               head_tags: torch.LongTensor = None,
               head_indices: torch.LongTensor = None,
               grammar_values: torch.LongTensor = None,
               lemma_indices: torch.LongTensor = None):

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        grammar_value_logits = self._gram_val_output(encoded_text)
        predicted_gram_vals = grammar_value_logits.argmax(-1)

        # Feed the predicted grammar-value logits into the lemmatizer input -- EXPERIMENTAL
        #l_ext_input = encoded_text
        l_ext_input = torch.cat([encoded_text, grammar_value_logits], -1)
        lemma_logits = self._lemma_output(l_ext_input)
        predicted_lemmas = lemma_logits.argmax(-1)

        # Retrieve the top-N most probable lemmatization variants and their probability estimates
        lemma_probs = torch.nn.functional.softmax(lemma_logits, -1)
        top_lemmas_indices = (-lemma_logits).argsort(-1)[:, :, :self.TopNCnt]
        #top_lemmas_indices = (-lemma_probs).argsort(-1)[:,:,:self.TopNCnt]
        top_lemmas_prob = torch.gather(lemma_probs, -1, top_lemmas_indices)
        #top_lemmas_prob = torch.gather(lemma_logits, -1, top_lemmas_indices)

        # Likewise for the grammemes
        gramm_probs = torch.nn.functional.softmax(grammar_value_logits, -1)
        top_gramms_indices = (
            -grammar_value_logits).argsort(-1)[:, :, :self.TopNCnt]
        top_gramms_prob = torch.gather(gramm_probs, -1, top_gramms_indices)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        token_mask = mask.float()
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            synt_prediction, benrg = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
            predicted_heads, predicted_head_tags = synt_prediction

        # Retrieve the top-N most probable LOCAL (non-MST) parse variants and their probability estimates
        benrgf = torch.flatten(benrg, start_dim=1, end_dim=2).permute(
            0, 2, 1)  # fuses the dependency-relation type with the parent index
        top_deprels_indices = (-benrgf).argsort(
            -1)[:, :, :self.TopNCnt]  # select the best combinations
        top_deprels_prob = torch.gather(benrgf, -1, top_deprels_indices)
        seqlen = benrg.shape[2]
        top_heads = torch.fmod(top_deprels_indices, seqlen)
        top_deprels_indices = torch.div(top_deprels_indices, seqlen,
                                        rounding_mode='floor')  # integer division; torch.floor does not work on integer tensors

        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask)
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask)

        grammar_nll = torch.tensor(0.)
        if grammar_values is not None:
            grammar_nll = self._update_multiclass_prediction_metrics(
                logits=grammar_value_logits,
                targets=grammar_values,
                mask=token_mask,
                accuracy_metric=self._gram_val_prediction_accuracy)

        lemma_nll = torch.tensor(0.)
        if lemma_indices is not None:
            lemma_nll = self._update_multiclass_prediction_metrics(
                logits=lemma_logits,
                targets=lemma_indices,
                mask=token_mask,
                accuracy_metric=self._lemma_prediction_accuracy,
                masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "gram_vals": predicted_gram_vals,
            "lemmas": predicted_lemmas,
            "mask": mask,
            "arc_nll": arc_nll,
            "tag_nll": tag_nll,
            "grammar_nll": grammar_nll,
            "lemma_nll": lemma_nll,
            "top_lemmas": top_lemmas_indices,
            "top_lemmas_prob": top_lemmas_prob,
            "top_gramms": top_gramms_indices,
            "top_gramms_prob": top_gramms_prob,
            "top_heads": top_heads,
            "top_deprels": top_deprels_indices,
            "top_deprels_prob": top_deprels_prob,
        }

        return output_dict
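The negate/argsort/gather pattern behind top_lemmas_indices, top_gramms_indices, and top_deprels_indices works on any logits tensor; a compact sketch:

    import torch

    logits = torch.tensor([[[0.1, 2.0, -1.0, 0.5]]])  # (batch, seq, classes)
    probs = torch.softmax(logits, -1)
    top_idx = (-logits).argsort(-1)[:, :, :2]         # indices of the two largest logits
    top_prob = torch.gather(probs, -1, top_idx)
    # top_idx:  tensor([[[1, 3]]])
    # top_prob: the matching softmax probabilities, in descending order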
Example #9
 def _get_valid_tokens_mask(self, tensor: torch.LongTensor) -> torch.ByteTensor:
     valid_tokens_mask = tensor.new_ones(tensor.size(), dtype=torch.uint8)
     for index in self._ignored_indices:
         valid_tokens_mask = valid_tokens_mask & (tensor != index).byte()
     return valid_tokens_mask
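A standalone usage sketch (treating ids 0 and 1 as the ignored indices is an assumption for illustration):

    import torch

    def get_valid_tokens_mask(tensor, ignored_indices):
        valid = tensor.new_ones(tensor.size(), dtype=torch.uint8)
        for index in ignored_indices:
            valid = valid & (tensor != index).byte()
        return valid

    tokens = torch.tensor([[3, 0, 7, 1]])
    print(get_valid_tokens_mask(tokens, ignored_indices=[0, 1]))
    # tensor([[1, 0, 1, 0]], dtype=torch.uint8)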
Example #10
    def _parse(self,
               embedded_text_input: torch.Tensor,
               mask: torch.LongTensor,
               head_tags: torch.LongTensor = None,
               head_indices: torch.LongTensor = None,
               grammar_values: torch.LongTensor = None,
               lemma_indices: torch.LongTensor = None):

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        grammar_value_logits = self._gram_val_output(encoded_text)
        predicted_gram_vals = grammar_value_logits.argmax(-1)

        lemma_logits = self._lemma_output(encoded_text)
        predicted_lemmas = lemma_logits.argmax(-1)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        token_mask = mask.float()
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask)
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask)

        grammar_nll = torch.tensor(0.)
        if grammar_values is not None:
            grammar_nll = self._update_multiclass_prediction_metrics(
                logits=grammar_value_logits,
                targets=grammar_values,
                mask=token_mask,
                accuracy_metric=self._gram_val_prediction_accuracy)

        lemma_nll = torch.tensor(0.)
        if lemma_indices is not None:
            lemma_nll = self._update_multiclass_prediction_metrics(
                logits=lemma_logits,
                targets=lemma_indices,
                mask=token_mask,
                accuracy_metric=self._lemma_prediction_accuracy,
                masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "gram_vals": predicted_gram_vals,
            "lemmas": predicted_lemmas,
            "mask": mask,
            "arc_nll": arc_nll,
            "tag_nll": tag_nll,
            "grammar_nll": grammar_nll,
            "lemma_nll": lemma_nll,
        }

        return output_dict
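Prepending the head sentinel shifts every sequence position by one, which is why the gold head_indices (and head_tags) get a zero column prepended as well. A shape-only sketch:

    import torch

    batch_size, seq_len, dim = 2, 4, 6
    encoded_text = torch.randn(batch_size, seq_len, dim)
    head_sentinel = torch.randn(1, 1, dim).expand(batch_size, 1, dim)
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    head_indices = torch.randint(0, seq_len, (batch_size, seq_len))

    encoded_text = torch.cat([head_sentinel, encoded_text], 1)  # (2, 5, 6)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)   # (2, 5)
    head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)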
Example #11
    def _parse(self,
               embedded_text_input: torch.Tensor,
               mask: torch.LongTensor,
               head_tags: torch.LongTensor = None,
               head_indices: torch.LongTensor = None,
               grammar_values: torch.LongTensor = None,
               lemma_indices: torch.LongTensor = None):

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        # Add a dimension that maps each encoder output to three copies of itself
        encoded_text_3 = encoded_text
        encoded_text_3 = torch.unsqueeze(encoded_text_3, 2)
        encoded_text_3 = encoded_text_3.repeat(1,1,3,1)
        # Run the three copies of each encoder-output vector through an LSTM
        seq_len = encoded_text.size()[1]
        emb_div_val = encoded_text.size()[2]
        multi_triplets = torch.reshape(encoded_text_3, (-1, 3, emb_div_val))
        label_variants, _ = self.multilabeler_lstm(multi_triplets)
        batched_label_variants = torch.reshape(label_variants, (-1, seq_len, 3, emb_div_val))
#         # debug output
#         print("\n\n------------------------------------------------- ITLOG-BEGIN ------------------------------------------\n")
#         print( "ITLOG: encoded_text.size() = {}".format(encoded_text.size()) )
#         print( "ITLOG: encoded_text_3.size() = {}".format(encoded_text_3.size()) )
#         print( "ITLOG: multi_triplets.size() = {}".format(multi_triplets.size()) )
#         print( "ITLOG: label_variants.size() = {}".format(label_variants.size()) )
#         print( "ITLOG: batched_label_variants.size() = {}".format(batched_label_variants.size()) )
#         print("\n------------------------------------------------- ITLOG-END ------------------------------------------\n")

#        grammar_value_logits = self._gram_val_output(encoded_text)
        grammar_value_logits = self._gram_val_output(batched_label_variants)
#         print("\n\n------------------------------------------------- ITLOG-BEGIN ------------------------------------------\n")
#         print( "ITLOG: grammar_value_logits.size() = {}".format(grammar_value_logits.size()) )
#         print("\n------------------------------------------------- ITLOG-END ------------------------------------------\n")
#        grammar_value_logits = grammar_value_logits.select(2, 0)
        predicted_gram_vals = grammar_value_logits.argmax(-1)

#        lemma_logits = self._lemma_output(encoded_text)
        lemma_logits = self._lemma_output(batched_label_variants)
        predicted_lemmas = lemma_logits.argmax(-1)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        token_mask = mask.float()
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)

        grammar_nll = torch.tensor(0.)
        if grammar_values is not None:
            token_mask_3 = token_mask
            token_mask_3 = torch.unsqueeze(token_mask_3, 2)
            token_mask_3 = token_mask_3.repeat(1,1,3)            
#             print("\n\n------------------------------------------------- ITLOG-BEGIN ------------------------------------------\n")
#             print( "ITLOG: token_mask.size = {}".format(token_mask.size()) )
#             print( "ITLOG: token_mask_3.size = {}".format(token_mask_3.size()) )
#             print( "ITLOG: token_mask_3 = {}".format(token_mask_3) )
#             print("\n------------------------------------------------- ITLOG-END ------------------------------------------\n")            
            grammar_nll = self._update_multiclass_prediction_metrics_3(
                logits=grammar_value_logits, targets=grammar_values,
                mask=token_mask_3, accuracy_metric=self._gram_val_prediction_accuracy
            )

        lemma_nll = torch.tensor(0.)
        if lemma_indices is not None:
            token_mask_3 = token_mask
            token_mask_3 = torch.unsqueeze(token_mask_3, 2)
            token_mask_3 = token_mask_3.repeat(1,1,3)            
            lemma_nll = self._update_multiclass_prediction_metrics_3(
                logits=lemma_logits, targets=lemma_indices,
                mask=token_mask_3, accuracy_metric=self._lemma_prediction_accuracy #, masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX
            )

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "gram_vals": predicted_gram_vals,
            "lemmas": predicted_lemmas,
            "mask": mask,
            "arc_nll": arc_nll,
            "tag_nll": tag_nll,
            "grammar_nll": grammar_nll,
            "lemma_nll": lemma_nll,
        }

        return output_dict
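The triplication step reshapes (batch, seq, dim) activations into (batch*seq, 3, dim) so the multilabeler LSTM can emit three alternative label vectors per token. A shape-only sketch, assuming the LSTM's hidden size equals the encoder dimension:

    import torch
    import torch.nn as nn

    batch_size, seq_len, dim = 2, 5, 16
    encoded = torch.randn(batch_size, seq_len, dim)
    multilabeler_lstm = nn.LSTM(dim, dim, batch_first=True)

    x = encoded.unsqueeze(2).repeat(1, 1, 3, 1)               # (2, 5, 3, 16)
    triplets = x.reshape(-1, 3, dim)                          # (10, 3, 16)
    variants, _ = multilabeler_lstm(triplets)                 # (10, 3, 16)
    variants = variants.reshape(batch_size, seq_len, 3, dim)  # (2, 5, 3, 16)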