def _get_pred_mask(self, output_symbols: torch.LongTensor) -> torch.BoolTensor:
    """Mask every decoded position that comes after the first EOS symbol.

    A position is ``True`` when no *earlier* position in the same row equals
    ``self.eos_id``. The first EOS itself therefore stays ``True`` and
    everything after it becomes ``False``. Position 0 is always ``True``.

    Args:
        output_symbols: 2-D tensor of symbol ids, one row per sequence.

    Returns:
        Boolean tensor with the same shape and device as ``output_symbols``.

    Raises:
        ValueError: If ``output_symbols`` is not two-dimensional.
    """
    if output_symbols.dim() != 2:
        raise ValueError(
            "output_symbols must be 2-D (batch_size, max_pred_len), "
            f"got {output_symbols.dim()} dimension(s)"
        )
    is_eos = (output_symbols == self.eos_id).long()
    # Exclusive running count: how many EOS symbols occur strictly before
    # each position. This replaces the step-by-step Python loop with one
    # vectorised pass along the time axis.
    eos_seen_before = is_eos.cumsum(dim=1) - is_eos
    return eos_seen_before == 0
def __call__(self, tokens: torch.LongTensor, prefix_mask: torch.LongTensor):
    """Randomly replace eligible prefix tokens with the OOV id.

    A token is eligible when it is not one of ``self.excludes`` and its
    ``prefix_mask`` entry is non-zero. Each eligible token is replaced by
    ``self.oov`` when a fresh uniform sample falls below ``self.mask_prob``.

    Args:
        tokens: Tensor of token ids.
        prefix_mask: Mask combined element-wise with ``tokens``; non-zero
            (or ``True``) marks prefix positions. Both boolean and integer
            masks are accepted.

    Returns:
        A tuple ``(result, oov_mask)``: the tokens after replacement and the
        boolean mask of positions that were replaced.
    """
    # Start from "everything replaceable" and knock out the excluded ids.
    replaceable = tokens.new_ones(tokens.size(), dtype=torch.bool)
    for excluded in self.excludes:
        replaceable &= tokens != excluded
    # Only mask prefixes since the others won't be attended.
    # The cast matters: an in-place ``&=`` between a bool tensor and an
    # integer mask fails because the promoted result cannot be written back
    # into the bool tensor.
    replaceable &= prefix_mask.bool()

    # Create a uniformly random mask selecting either the original words or
    # OOV tokens.
    dropout_mask = tokens.new_empty(tokens.size(), dtype=torch.float).uniform_() < self.mask_prob
    oov_mask = dropout_mask & replaceable
    oov_fill = tokens.new_full(tokens.size(), self.oov, dtype=torch.long)
    result = torch.where(oov_mask, oov_fill, tokens)
    return result, oov_mask
def contains(self, batch: torch.LongTensor) -> torch.BoolTensor:
    """
    Test a batch of triples for (approximate) membership.

    Every index tensor yielded by ``self.probe(batch)`` is looked up in
    ``self.bit_array``; a triple is reported as contained only if all of
    its probed bits are set.

    :param batch: shape (batch_size, 3)
        The batch of triples.

    :return: shape: (batch_size,)
        The result. False guarantees that the element was not contained in
        the indexed triples. True can be erroneous.
    """
    num_triples = batch.shape[0]
    hits = batch.new_ones(num_triples, dtype=torch.bool)
    for bit_positions in self.probe(batch):
        probed_bits = self.bit_array[bit_positions]
        hits &= probed_bits
    return hits
def token_dropout(tokens: torch.LongTensor,
                  oov_token: int,
                  exclude_tokens: List[int],
                  p: float = 0.2,
                  training: bool = True) -> torch.LongTensor:
    """Randomly replace non-excluded tokens with a mask token.

    During training, every token whose id is not in ``exclude_tokens`` is
    replaced by ``oov_token`` with probability ``p``.

    Adopted from https://github.com/Hyperparticle/udify

    Args:
        tokens: The current batch of padded sentences with word ids.
        oov_token: The mask token.
        exclude_tokens: Token ids that must never be replaced (e.g. the
            ids used for padding the input batch).
        p: The probability a word gets mapped to the mask token.
        training: Applies the dropout if set to ``True``.

    Returns:
        A new tensor with token dropout applied. When ``training`` is falsy
        or ``p`` is not positive, the input tensor itself is returned
        unchanged (not a copy).
    """
    if not training or p <= 0:
        return tokens

    # Mask of positions that are allowed to be mapped to the OOV token.
    replaceable = tokens.new_ones(tokens.size(), dtype=torch.bool)
    for excluded in exclude_tokens:
        replaceable &= tokens != excluded

    # Uniformly random mask selecting either the original words or OOV tokens.
    dropout_mask = tokens.new_empty(tokens.size(), dtype=torch.float).uniform_() < p
    oov_mask = dropout_mask & replaceable

    oov_fill = tokens.new_full(tokens.size(), oov_token, dtype=torch.long)
    return torch.where(oov_mask, oov_fill, tokens)
def _action_to_token(self, action_tokens: torch.LongTensor, draft_tokens: torch.LongTensor) -> torch.LongTensor:
    """Replay a sequence of edit actions over a draft to produce output tokens.

    The actions are consumed one time step (column of ``action_tokens``) at
    a time. Each row keeps a read position into ``draft_tokens`` (starting
    at 1) and a write position into the output (starting at 0):

    * ``self.KEEP`` writes the current draft token and advances both
      positions.
    * ``self.DROP`` advances only the read position. The draft token is
      still written to the current output slot, but because the write
      position does not move, a later step overwrites that slot.
    * Any other value is written to the output as a token id and advances
      only the write position.

    Output slots that are never written keep the value ``self.END``.

    Args:
        action_tokens: 2-D tensor of actions, one row per sequence.
        draft_tokens: 2-D tensor of draft token ids, one row per sequence.

    Returns:
        Tensor shaped like ``action_tokens`` holding the produced tokens.
    """
    num_rows = draft_tokens.size(0)
    write_pos = action_tokens.new_zeros((num_rows, 1))
    read_pos = draft_tokens.new_ones((num_rows, 1))
    output = action_tokens.new_full((action_tokens.size()), self.END)

    for step_actions in action_tokens.t():
        # KEEP, DELETE, COPY, ADD (other)
        is_keep = step_actions == self.KEEP
        is_drop = step_actions == self.DROP
        consumes_draft = is_keep | is_drop
        is_add = ~consumes_draft

        # First copy the current draft token into every row's write slot ...
        copied_from_draft = draft_tokens.gather(1, read_pos)
        output.scatter_(1, write_pos, copied_from_draft)

        # ... then, for rows that add a token, overwrite that slot with the
        # action value itself.
        added_rows = output[is_add]
        added_tokens = step_actions[is_add].unsqueeze(1)
        output[is_add] = added_rows.scatter(1, write_pos[is_add], added_tokens)

        read_pos[consumes_draft] += 1
        write_pos[~is_drop] += 1

    return output
def _parse(
    self,
    encoded_text: torch.Tensor,
    mask: torch.LongTensor,
    head_tags: torch.LongTensor = None,
    head_indices: torch.LongTensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score arcs, decode heads and head tags, and compute both losses.

    A sentinel position (``self._head_sentinel``) is prepended to the
    encoded sequence, and a matching column is prepended to ``mask`` (ones)
    and to the gold ``head_indices`` / ``head_tags`` (zeros) when given.

    Args:
        encoded_text: 3-D tensor of encoded tokens.
        mask: Token mask; zero marks padding.
        head_tags: Optional gold head tags.
        head_indices: Optional gold head indices.

    Returns:
        ``(predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll)``
        where ``mask`` includes the sentinel column. The losses are computed
        against the gold annotations when both are supplied, otherwise
        against the model's own predictions.
    """
    batch_size, _, encoding_dim = encoded_text.size()

    # Concatenate the head sentinel onto the sentence representation.
    sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    encoded_text = torch.cat([sentinel, encoded_text], 1)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        sentinel_head = head_indices.new_zeros(batch_size, 1)
        head_indices = torch.cat([sentinel_head, head_indices], 1)
    if head_tags is not None:
        sentinel_tag = head_tags.new_zeros(batch_size, 1)
        head_tags = torch.cat([sentinel_tag, head_tags], 1)

    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc = self._dropout(self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    arc_scores = self.arc_attention(head_arc, child_arc)

    # Push every score whose row or column is a padded position far down.
    padding_penalty = (1 - mask.float()) * -1e8
    arc_scores = arc_scores + padding_penalty.unsqueeze(2) + padding_penalty.unsqueeze(1)

    use_greedy = self.training or not self.use_mst_decoding_for_validation
    decode = self._greedy_decode if use_greedy else self._mst_decode
    predicted_heads, predicted_head_tags = decode(head_tag, child_tag, arc_scores, mask)

    if head_indices is not None and head_tags is not None:
        target_heads, target_tags = head_indices, head_tags
    else:
        # No complete gold annotation: score the model's own predictions.
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(
        head_tag_representation=head_tag,
        child_tag_representation=child_tag,
        attended_arcs=arc_scores,
        head_indices=target_heads,
        head_tags=target_tags,
        mask=mask,
    )

    return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll
def forward(self,  # type: ignore
            encoded_text: torch.FloatTensor,
            mask: torch.LongTensor,
            pos_logits: torch.LongTensor = None,  # predicted
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """Run the dependency parser head over already-encoded text.

    Optionally concatenates embeddings of the arg-max POS predictions to
    the input, re-encodes it with ``self.encoder``, prepends the head
    sentinel, scores arcs, decodes heads and head tags, and computes the
    arc and tag losses.

    Args:
        encoded_text: 3-D tensor of encoded tokens.
        mask: Token mask; zero marks padding.
        pos_logits: Optional POS scores. Used only when
            ``self.pos_tag_embedding`` is set; the arg-max over the last
            axis is embedded and concatenated to ``encoded_text``.
        head_tags: Optional gold head tags.
        head_indices: Optional gold head indices.
        metadata: Optional per-instance dictionaries with a ``"words"``
            entry.

    Returns:
        A dictionary with ``heads``, ``head_tags``, ``arc_loss``,
        ``tag_loss``, ``loss`` and ``mask`` (including the sentinel column).
        ``words`` is added only when ``metadata`` is provided.
    """
    batch_size, _, _ = encoded_text.size()

    pos_tags = None
    if pos_logits is not None and self.pos_tag_embedding is not None:
        # Embed the predicted POS tags and concatenate the embeddings to the input
        num_pos_classes = pos_logits.size(-1)
        pos_logits = pos_logits.view(-1, num_pos_classes)
        _, pos_tags = pos_logits.max(-1)

        pos_embed_size = self.pos_tag_embedding.get_output_dim()
        # NOTE: this uses ``self.dropout``, while the rest of the method
        # uses ``self._dropout``.
        embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags))
        embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size)
        encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1)

    encoded_text = self.encoder(encoded_text, mask)

    batch_size, _, encoding_dim = encoded_text.size()

    head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Concatenate the head sentinel onto the sentence representation.
    encoded_text = torch.cat([head_sentinel, encoded_text], 1)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
    float_mask = mask.float()
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation,
                                       child_arc_representation)

    # Push every score whose row or column is a padded position far down.
    minus_inf = -1e8
    minus_mask = (1 - float_mask) * minus_inf
    attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

    if self.training or not self.use_mst_decoding_for_validation:
        predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                   child_tag_representation,
                                                                   attended_arcs,
                                                                   mask)
    else:
        predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                child_tag_representation,
                                                                attended_arcs,
                                                                mask)

    if head_indices is not None and head_tags is not None:
        arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                child_tag_representation=child_tag_representation,
                                                attended_arcs=attended_arcs,
                                                head_indices=head_indices,
                                                head_tags=head_tags,
                                                mask=mask)
        loss = arc_nll + tag_nll

        evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
        # We calculate attachment scores for the whole sentence
        # but excluding the symbolic ROOT token at the start,
        # which is why we start from the second element in the sequence.
        self._attachment_scores(predicted_heads[:, 1:],
                                predicted_head_tags[:, 1:],
                                head_indices[:, 1:],
                                head_tags[:, 1:],
                                evaluation_mask)
    else:
        # No complete gold annotation: score the model's own predictions.
        arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                child_tag_representation=child_tag_representation,
                                                attended_arcs=attended_arcs,
                                                head_indices=predicted_heads.long(),
                                                head_tags=predicted_head_tags.long(),
                                                mask=mask)
        loss = arc_nll + tag_nll

    output_dict = {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "arc_loss": arc_nll,
        "tag_loss": tag_nll,
        "loss": loss,
        "mask": mask,
    }
    # ``metadata`` defaults to None; iterating it unconditionally raised a
    # TypeError for callers that do not pass metadata.
    if metadata is not None:
        output_dict["words"] = [meta["words"] for meta in metadata]

    return output_dict
def _parse(self,
           embedded_text_input: torch.Tensor,
           mask: torch.LongTensor,
           head_tags: torch.LongTensor = None,
           head_indices: torch.LongTensor = None,
           grammar_values: torch.LongTensor = None,
           lemma_indices: torch.LongTensor = None):
    """Jointly predict grammar values, lemmas and the dependency parse.

    Besides the arg-max predictions, the top ``self.TopNCnt`` lemma and
    grammar-value candidates (with softmax probabilities) are returned.
    When MST decoding is used, the top ``self.TopNCnt`` local
    (head, relation) candidates per token are returned as well.

    Args:
        embedded_text_input: Embedded tokens fed to ``self.encoder``.
        mask: Token mask; zero marks padding.
        head_tags: Optional gold head tags.
        head_indices: Optional gold head indices.
        grammar_values: Optional gold grammar values.
        lemma_indices: Optional gold lemma rule indices.

    Returns:
        A dictionary with the predictions (``heads``, ``head_tags``,
        ``gram_vals``, ``lemmas``), the ``mask`` including the sentinel
        column, the four losses (``arc_nll``, ``tag_nll``, ``grammar_nll``,
        ``lemma_nll``) and the top-N outputs (``top_lemmas``,
        ``top_lemmas_prob``, ``top_gramms``, ``top_gramms_prob``,
        ``top_heads``, ``top_deprels``, ``top_deprels_prob``). The last
        three are ``None`` when greedy decoding is used, because only
        ``self._mst_decode`` returns the scores they are derived from.
    """
    embedded_text_input = self._input_dropout(embedded_text_input)
    encoded_text = self.encoder(embedded_text_input, mask)

    grammar_value_logits = self._gram_val_output(encoded_text)
    predicted_gram_vals = grammar_value_logits.argmax(-1)

    # EXPERIMENTAL: feed the grammar-value prediction output into the
    # lemmatizer input.
    l_ext_input = torch.cat([encoded_text, grammar_value_logits], -1)
    lemma_logits = self._lemma_output(l_ext_input)
    predicted_lemmas = lemma_logits.argmax(-1)

    # Top-N most probable lemmatization candidates and their probabilities.
    lemma_probs = torch.nn.functional.softmax(lemma_logits, -1)
    top_lemmas_indices = (-lemma_logits).argsort(-1)[:, :, :self.TopNCnt]
    top_lemmas_prob = torch.gather(lemma_probs, -1, top_lemmas_indices)

    # Same for the grammemes.
    gramm_probs = torch.nn.functional.softmax(grammar_value_logits, -1)
    top_gramms_indices = (-grammar_value_logits).argsort(-1)[:, :, :self.TopNCnt]
    top_gramms_prob = torch.gather(gramm_probs, -1, top_gramms_indices)

    batch_size, _, encoding_dim = encoded_text.size()

    head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Concatenate the head sentinel onto the sentence representation.
    encoded_text = torch.cat([head_sentinel, encoded_text], 1)

    # Token-level mask (without the sentinel) for the tagging losses.
    token_mask = mask.float()
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat(
            [head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat(
            [head_tags.new_zeros(batch_size, 1), head_tags], 1)
    float_mask = mask.float()
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self._dropout(
        self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(
        self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(
        self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(
        self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation,
                                       child_arc_representation)

    # Push every score whose row or column is a padded position far down.
    minus_inf = -1e8
    minus_mask = (1 - float_mask) * minus_inf
    attended_arcs = attended_arcs + minus_mask.unsqueeze(
        2) + minus_mask.unsqueeze(1)

    # The top-N parse candidates need the scores that only the MST decoder
    # returns, so they stay None on the greedy path instead of raising
    # NameError.
    top_heads = None
    top_deprels_indices = None
    top_deprels_prob = None

    if self.training or not self.use_mst_decoding_for_validation:
        predicted_heads, predicted_head_tags = self._greedy_decode(
            head_tag_representation, child_tag_representation,
            attended_arcs, mask)
    else:
        synt_prediction, benrg = self._mst_decode(
            head_tag_representation, child_tag_representation,
            attended_arcs, mask)
        predicted_heads, predicted_head_tags = synt_prediction

        # Top-N most probable LOCAL (not MST) parse candidates and their
        # scores.
        # NOTE(review): presumably benrg is
        # (batch, relation types, seq_len, seq_len) — confirm against
        # _mst_decode.
        # Merge the relation-type axis with the head-index axis.
        benrgf = torch.flatten(benrg, start_dim=1, end_dim=2).permute(0, 2, 1)
        # Pick the best combinations.
        top_deprels_indices = (-benrgf).argsort(-1)[:, :, :self.TopNCnt]
        top_deprels_prob = torch.gather(benrgf, -1, top_deprels_indices)
        seqlen = benrg.shape[2]
        # Split the merged index back into its two parts.
        top_heads = torch.fmod(top_deprels_indices, seqlen)
        # Floor division keeps the indices integral; ``torch.div`` on
        # integer tensors performs true division in current PyTorch and
        # would return floats.
        top_deprels_indices = top_deprels_indices // seqlen

    if head_indices is not None and head_tags is not None:
        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=head_indices,
            head_tags=head_tags,
            mask=mask)
    else:
        # No complete gold annotation: score the model's own predictions.
        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=predicted_heads.long(),
            head_tags=predicted_head_tags.long(),
            mask=mask)

    grammar_nll = torch.tensor(0.)
    if grammar_values is not None:
        grammar_nll = self._update_multiclass_prediction_metrics(
            logits=grammar_value_logits,
            targets=grammar_values,
            mask=token_mask,
            accuracy_metric=self._gram_val_prediction_accuracy)

    lemma_nll = torch.tensor(0.)
    if lemma_indices is not None:
        lemma_nll = self._update_multiclass_prediction_metrics(
            logits=lemma_logits,
            targets=lemma_indices,
            mask=token_mask,
            accuracy_metric=self._lemma_prediction_accuracy,
            masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

    output_dict = {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "gram_vals": predicted_gram_vals,
        "lemmas": predicted_lemmas,
        "mask": mask,
        "arc_nll": arc_nll,
        "tag_nll": tag_nll,
        "grammar_nll": grammar_nll,
        "lemma_nll": lemma_nll,
        "top_lemmas": top_lemmas_indices,
        "top_lemmas_prob": top_lemmas_prob,
        "top_gramms": top_gramms_indices,
        "top_gramms_prob": top_gramms_prob,
        "top_heads": top_heads,
        "top_deprels": top_deprels_indices,
        "top_deprels_prob": top_deprels_prob,
    }

    return output_dict
def _get_valid_tokens_mask(self, tensor: torch.LongTensor) -> torch.ByteTensor:
    """Flag the entries of ``tensor`` that are not ignored indices.

    Args:
        tensor: Tensor of indices to check.

    Returns:
        A ``uint8`` tensor shaped like ``tensor`` holding 1 where the value
        is none of ``self._ignored_indices`` and 0 elsewhere.
    """
    keep = tensor.new_ones(tensor.size(), dtype=torch.uint8)
    for ignored in self._ignored_indices:
        not_ignored = tensor.ne(ignored).byte()
        keep &= not_ignored
    return keep
def _parse(self,
           embedded_text_input: torch.Tensor,
           mask: torch.LongTensor,
           head_tags: torch.LongTensor = None,
           head_indices: torch.LongTensor = None,
           grammar_values: torch.LongTensor = None,
           lemma_indices: torch.LongTensor = None):
    """Jointly predict grammar values, lemmas and the dependency parse.

    The input is encoded once. Grammar values and lemmas are predicted
    from the encoder output directly; the parser part additionally gets
    the head sentinel prepended.

    Args:
        embedded_text_input: Embedded tokens fed to ``self.encoder``.
        mask: Token mask; zero marks padding.
        head_tags: Optional gold head tags.
        head_indices: Optional gold head indices.
        grammar_values: Optional gold grammar values.
        lemma_indices: Optional gold lemma rule indices.

    Returns:
        A dictionary with ``heads``, ``head_tags``, ``gram_vals``,
        ``lemmas``, ``mask`` (including the sentinel column), ``arc_nll``,
        ``tag_nll``, ``grammar_nll`` and ``lemma_nll``. A loss whose gold
        annotation is missing is a zero tensor (grammar, lemma) or is
        computed against the model's own predictions (arc, tag).
    """
    encoded_text = self.encoder(self._input_dropout(embedded_text_input), mask)

    grammar_value_logits = self._gram_val_output(encoded_text)
    predicted_gram_vals = grammar_value_logits.argmax(-1)

    lemma_logits = self._lemma_output(encoded_text)
    predicted_lemmas = lemma_logits.argmax(-1)

    batch_size, _, encoding_dim = encoded_text.size()

    # Concatenate the head sentinel onto the sentence representation.
    sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    encoded_text = torch.cat([sentinel, encoded_text], 1)

    # Token-level mask (without the sentinel) for the tagging losses.
    token_mask = mask.float()
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        sentinel_head = head_indices.new_zeros(batch_size, 1)
        head_indices = torch.cat([sentinel_head, head_indices], 1)
    if head_tags is not None:
        sentinel_tag = head_tags.new_zeros(batch_size, 1)
        head_tags = torch.cat([sentinel_tag, head_tags], 1)

    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc = self._dropout(self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    arc_scores = self.arc_attention(head_arc, child_arc)

    # Push every score whose row or column is a padded position far down.
    padding_penalty = (1 - mask.float()) * -1e8
    arc_scores = arc_scores + padding_penalty.unsqueeze(2) + padding_penalty.unsqueeze(1)

    use_greedy = self.training or not self.use_mst_decoding_for_validation
    decode = self._greedy_decode if use_greedy else self._mst_decode
    predicted_heads, predicted_head_tags = decode(head_tag, child_tag, arc_scores, mask)

    if head_indices is not None and head_tags is not None:
        target_heads, target_tags = head_indices, head_tags
    else:
        # No complete gold annotation: score the model's own predictions.
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(
        head_tag_representation=head_tag,
        child_tag_representation=child_tag,
        attended_arcs=arc_scores,
        head_indices=target_heads,
        head_tags=target_tags,
        mask=mask)

    if grammar_values is None:
        grammar_nll = torch.tensor(0.)
    else:
        grammar_nll = self._update_multiclass_prediction_metrics(
            logits=grammar_value_logits,
            targets=grammar_values,
            mask=token_mask,
            accuracy_metric=self._gram_val_prediction_accuracy)

    if lemma_indices is None:
        lemma_nll = torch.tensor(0.)
    else:
        lemma_nll = self._update_multiclass_prediction_metrics(
            logits=lemma_logits,
            targets=lemma_indices,
            mask=token_mask,
            accuracy_metric=self._lemma_prediction_accuracy,
            masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

    return {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "gram_vals": predicted_gram_vals,
        "lemmas": predicted_lemmas,
        "mask": mask,
        "arc_nll": arc_nll,
        "tag_nll": tag_nll,
        "grammar_nll": grammar_nll,
        "lemma_nll": lemma_nll,
    }
def _parse(self,
           embedded_text_input: torch.Tensor,
           mask: torch.LongTensor,
           head_tags: torch.LongTensor = None,
           head_indices: torch.LongTensor = None,
           grammar_values: torch.LongTensor = None,
           lemma_indices: torch.LongTensor = None):
    """Predict three label variants per token plus the dependency parse.

    Every encoder output is copied three times and the copies are run
    through ``self.multilabeler_lstm``; grammar values and lemmas are
    predicted from each of the three resulting vectors. The parser part
    works on the plain encoder output with the head sentinel prepended.

    Args:
        embedded_text_input: Embedded tokens fed to ``self.encoder``.
        mask: Token mask; zero marks padding.
        head_tags: Optional gold head tags.
        head_indices: Optional gold head indices.
        grammar_values: Optional gold grammar values.
        lemma_indices: Optional gold lemma rule indices.

    Returns:
        A dictionary with ``heads``, ``head_tags``, ``gram_vals``,
        ``lemmas``, ``mask`` (including the sentinel column), ``arc_nll``,
        ``tag_nll``, ``grammar_nll`` and ``lemma_nll``. ``gram_vals`` and
        ``lemmas`` carry an extra axis of size three for the variants.
    """
    encoded_text = self.encoder(self._input_dropout(embedded_text_input), mask)

    # Add a dimension that maps every encoder output to three copies of
    # itself.
    encoded_text_3 = torch.unsqueeze(encoded_text, 2).repeat(1, 1, 3, 1)

    # Run the three copies of each encoder output vector through the LSTM.
    seq_len = encoded_text.size()[1]
    emb_div_val = encoded_text.size()[2]
    multi_triplets = torch.reshape(encoded_text_3, (-1, 3, emb_div_val))
    label_variants, _ = self.multilabeler_lstm(multi_triplets)
    # NOTE(review): reshaping with emb_div_val assumes the LSTM output width
    # equals the encoder output width — confirm against the LSTM config.
    batched_label_variants = torch.reshape(
        label_variants, (-1, seq_len, 3, emb_div_val))

    grammar_value_logits = self._gram_val_output(batched_label_variants)
    predicted_gram_vals = grammar_value_logits.argmax(-1)

    lemma_logits = self._lemma_output(batched_label_variants)
    predicted_lemmas = lemma_logits.argmax(-1)

    batch_size, _, encoding_dim = encoded_text.size()

    # Concatenate the head sentinel onto the sentence representation.
    sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    encoded_text = torch.cat([sentinel, encoded_text], 1)

    # Token-level mask (without the sentinel) for the tagging losses.
    token_mask = mask.float()
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        sentinel_head = head_indices.new_zeros(batch_size, 1)
        head_indices = torch.cat([sentinel_head, head_indices], 1)
    if head_tags is not None:
        sentinel_tag = head_tags.new_zeros(batch_size, 1)
        head_tags = torch.cat([sentinel_tag, head_tags], 1)

    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc = self._dropout(self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    arc_scores = self.arc_attention(head_arc, child_arc)

    # Push every score whose row or column is a padded position far down.
    padding_penalty = (1 - mask.float()) * -1e8
    arc_scores = arc_scores + padding_penalty.unsqueeze(2) + padding_penalty.unsqueeze(1)

    use_greedy = self.training or not self.use_mst_decoding_for_validation
    decode = self._greedy_decode if use_greedy else self._mst_decode
    predicted_heads, predicted_head_tags = decode(head_tag, child_tag, arc_scores, mask)

    if head_indices is not None and head_tags is not None:
        target_heads, target_tags = head_indices, head_tags
    else:
        # No complete gold annotation: score the model's own predictions.
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(
        head_tag_representation=head_tag,
        child_tag_representation=child_tag,
        attended_arcs=arc_scores,
        head_indices=target_heads,
        head_tags=target_tags,
        mask=mask)

    if grammar_values is None:
        grammar_nll = torch.tensor(0.)
    else:
        # Repeat the token mask for each of the three label variants.
        token_mask_3 = torch.unsqueeze(token_mask, 2).repeat(1, 1, 3)
        grammar_nll = self._update_multiclass_prediction_metrics_3(
            logits=grammar_value_logits,
            targets=grammar_values,
            mask=token_mask_3,
            accuracy_metric=self._gram_val_prediction_accuracy)

    if lemma_indices is None:
        lemma_nll = torch.tensor(0.)
    else:
        # Repeat the token mask for each of the three label variants.
        token_mask_3 = torch.unsqueeze(token_mask, 2).repeat(1, 1, 3)
        lemma_nll = self._update_multiclass_prediction_metrics_3(
            logits=lemma_logits,
            targets=lemma_indices,
            mask=token_mask_3,
            accuracy_metric=self._lemma_prediction_accuracy)

    return {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "gram_vals": predicted_gram_vals,
        "lemmas": predicted_lemmas,
        "mask": mask,
        "arc_nll": arc_nll,
        "tag_nll": tag_nll,
        "grammar_nll": grammar_nll,
        "lemma_nll": lemma_nll,
    }