def masked_topk(
    input_: torch.FloatTensor,
    mask: torch.BoolTensor,
    k: Union[int, torch.LongTensor],
    dim: int = -1,
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]:
    if input_.size() != mask.size():
        raise ValueError("`input_` and `mask` must have the same shape.")
    if not -input_.dim() <= dim < input_.dim():
        raise ValueError("`dim` must be in `[-input_.dim(), input_.dim())`")
    dim = (dim + input_.dim()) % input_.dim()

    max_k = k if isinstance(k, int) else k.max()

    permutation = list(range(input_.dim()))
    permutation.pop(dim)
    permutation += [dim]

    reverse_permutation = list(range(input_.dim() - 1))
    reverse_permutation.insert(dim, -1)

    other_dims_size = list(input_.size())
    other_dims_size.pop(dim)
    permuted_size = other_dims_size + [max_k]  # for restoration

    if isinstance(k, int):
        k = k * torch.ones(*other_dims_size, dtype=torch.long, device=mask.device)
    else:
        if list(k.size()) != other_dims_size:
            raise ValueError(
                "`k` must have the same shape as `input_` with dimension `dim` removed."
            )

    num_items = input_.size(dim)

    input_ = input_.permute(*permutation).reshape(-1, num_items)
    mask = mask.permute(*permutation).reshape(-1, num_items)
    k = k.reshape(-1)

    input_ = replace_masked_values(input_, mask, min_value_of_dtype(input_.dtype))

    _, top_indices = input_.topk(max_k, 1)

    top_indices_mask = get_mask_from_sequence_lengths(k, max_k).bool()
    fill_value, _ = top_indices.max(dim=1, keepdim=True)
    top_indices = torch.where(top_indices_mask, top_indices, fill_value)

    top_indices, _ = top_indices.sort(1)

    sequence_mask = mask.gather(1, top_indices)
    top_mask = top_indices_mask & sequence_mask

    top_input = input_.gather(1, top_indices)

    return (
        top_input.reshape(*permuted_size).permute(*reverse_permutation),
        top_mask.reshape(*permuted_size).permute(*reverse_permutation),
        top_indices.reshape(*permuted_size).permute(*reverse_permutation),
    )
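# A minimal usage sketch for `masked_topk`, assuming the helpers it relies on
# (`replace_masked_values`, `min_value_of_dtype`, `get_mask_from_sequence_lengths`)
# are importable alongside it. The tensors below are illustrative only.
import torch

scores = torch.tensor([[0.1, 0.9, 0.4],
                       [0.7, 0.2, 0.5]])
mask = torch.tensor([[True, True, False],
                     [True, True, True]])

top_scores, top_mask, top_indices = masked_topk(scores, mask, k=2)
# Row 0's last position is masked out, so its top-2 come from indices {0, 1};
# row 1's top-2 scores sit at indices 0 and 2. Indices are returned sorted
# ascending, and `top_mask` flags which of the k slots are backed by real
# (unmasked) items when k exceeds the number of unmasked positions in a row.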
def reset_states(self, mask: torch.BoolTensor = None) -> None:
    """
    Resets the internal states of a stateful encoder.

    # Parameters

    mask : `torch.BoolTensor`, optional.
        A tensor of shape `(batch_size,)` indicating which states should
        be reset. If not provided, all states will be reset.
    """
    if mask is None:
        self._states = None
    else:
        # state has shape (num_layers, batch_size, hidden_size). We reshape
        # mask to have shape (1, batch_size, 1) so that operations
        # broadcast properly.
        mask_batch_size = mask.size(0)
        mask = mask.view(1, mask_batch_size, 1)
        new_states = []
        for old_state in self._states:
            old_state_batch_size = old_state.size(1)
            if old_state_batch_size != mask_batch_size:
                raise ValueError(
                    f"Trying to reset states using mask with incorrect batch size. "
                    f"Expected batch size: {old_state_batch_size}. "
                    f"Provided batch size: {mask_batch_size}."
                )
            new_state = ~mask * old_state
            new_states.append(new_state.detach())
        self._states = tuple(new_states)
def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
    """
    # Parameters

    inputs : `torch.Tensor`, required.
        A Tensor of shape `(batch_size, sequence_length, hidden_size)`.
    mask : `torch.BoolTensor`, required.
        A binary mask of shape `(batch_size, sequence_length)` representing the
        non-padded elements in each sequence in the batch.

    # Returns

    A `torch.Tensor` of shape `(num_layers, batch_size, sequence_length, hidden_size)`,
    where the num_layers dimension represents the LSTM output from that layer.
    """
    batch_size, total_sequence_length = mask.size()
    stacked_sequence_output, final_states, restoration_indices = self.sort_and_run_forward(
        self._lstm_forward, inputs, mask
    )

    num_layers, num_valid, returned_timesteps, encoder_dim = stacked_sequence_output.size()
    # Add back invalid rows which were removed in the call to sort_and_run_forward.
    if num_valid < batch_size:
        zeros = stacked_sequence_output.new_zeros(
            num_layers, batch_size - num_valid, returned_timesteps, encoder_dim
        )
        stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 1)

        # The states also need to have invalid rows added back.
        new_states = []
        for state in final_states:
            state_dim = state.size(-1)
            zeros = state.new_zeros(num_layers, batch_size - num_valid, state_dim)
            new_states.append(torch.cat([state, zeros], 1))
        final_states = new_states

    # It's possible to need to pass sequences which are padded to longer than the
    # max length of the sequence to a Seq2StackEncoder. However, packing and unpacking
    # the sequences mean that the returned tensor won't include these dimensions, because
    # the RNN did not need to process them. We add them back on in the form of zeros here.
    sequence_length_difference = total_sequence_length - returned_timesteps
    if sequence_length_difference > 0:
        zeros = stacked_sequence_output.new_zeros(
            num_layers,
            batch_size,
            sequence_length_difference,
            stacked_sequence_output[0].size(-1),
        )
        stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 2)

    self._update_states(final_states, restoration_indices)

    # Restore the original indices and return the sequence.
    # Has shape (num_layers, batch_size, sequence_length, hidden_size)
    return stacked_sequence_output.index_select(1, restoration_indices)
def _manipulate_mask(
    self, mask: torch.BoolTensor, student_scores: torch.Tensor, batch: Batch
) -> torch.BoolTensor:
    """
    Add one extra (masked-out) token to the mask, for compatibility with BART.
    """
    assert student_scores.size(1) == batch.label_vec.size(1) + 1
    mask = torch.cat([mask.new_zeros([mask.size(0), 1]), mask], dim=1)
    return mask
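# A toy illustration of the shift performed above (values are hypothetical):
# BART prepends a start token to the decoder output, so the label mask gains one
# masked-out column at the front to stay aligned with `student_scores`.
import torch

label_mask = torch.tensor([[True, True, False]])
shifted = torch.cat([label_mask.new_zeros([label_mask.size(0), 1]), label_mask], dim=1)
# shifted -> tensor([[False,  True,  True, False]])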
def _greedy_decode(
    self,
    head_tag_representation: torch.Tensor,
    child_tag_representation: torch.Tensor,
    attended_arcs: torch.Tensor,
    mask: torch.BoolTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decodes the head and head tag predictions by decoding the unlabeled arcs
    independently for each word and then again, predicting the head tags of
    these greedily chosen arcs independently. Note that this method of decoding
    is not guaranteed to produce trees (i.e. there may be multiple roots,
    or cycles when children are attached to their parents).

    # Parameters

    head_tag_representation : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    attended_arcs : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to
        generate a distribution over attachments of a given word to all other words.
    mask : `torch.BoolTensor`, required.
        A mask of shape (batch_size, sequence_length) denoting the non-padded
        elements in each sequence.

    # Returns

    heads : `torch.Tensor`
        A tensor of shape (batch_size, sequence_length) representing the
        greedily decoded heads of each word.
    head_tags : `torch.Tensor`
        A tensor of shape (batch_size, sequence_length) representing the
        dependency tags of the greedily decoded heads of each word.
    """
    # Mask the diagonal, because the head of a word can't be itself.
    attended_arcs = attended_arcs + torch.diag(
        attended_arcs.new(mask.size(1)).fill_(-numpy.inf)
    )
    # Mask padded tokens, because we only want to consider actual words as heads.
    if mask is not None:
        minus_mask = ~mask.unsqueeze(2)
        attended_arcs.masked_fill_(minus_mask, -numpy.inf)

    # Compute the heads greedily.
    # shape (batch_size, sequence_length)
    _, heads = attended_arcs.max(dim=2)

    # Given the greedily predicted heads, decode their dependency tags.
    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(
        head_tag_representation, child_tag_representation, heads
    )
    _, head_tags = head_tag_logits.max(dim=2)
    return heads, head_tags
def forward(
    self, inputs: torch.Tensor, mask: torch.BoolTensor, hidden_state: torch.Tensor = None
) -> torch.Tensor:
    if mask is None:
        # If a mask isn't passed, there is no padding in the batch of instances, so we can just
        # return the last sequence output as the state. This doesn't work in the case of
        # variable length sequences, as the last state for each element of the batch won't be
        # at the end of the max sequence length, so we have to use the state of the RNN below.
        return self._module(inputs, hidden_state)[0][:, -1, :]

    batch_size = mask.size(0)

    (
        _,
        state,
        restoration_indices,
    ) = self.sort_and_run_forward(self._module, inputs, mask, hidden_state)

    # Deal with the fact the LSTM state is a tuple of (state, memory).
    if isinstance(state, tuple):
        state = state[0]

    num_layers_times_directions, num_valid, encoding_dim = state.size()
    # Add back invalid rows.
    if num_valid < batch_size:
        # batch size is the second dimension here, because pytorch
        # returns RNN state as a tensor of shape (num_layers * num_directions,
        # batch_size, hidden_size)
        zeros = state.new_zeros(
            num_layers_times_directions, batch_size - num_valid, encoding_dim
        )
        state = torch.cat([state, zeros], 1)

    # Restore the original indices and return the final state of the
    # top layer. Pytorch's recurrent layers return state in the form
    # (num_layers * num_directions, batch_size, hidden_size) regardless
    # of the 'batch_first' flag, so we transpose, extract the relevant
    # layer state (both forward and backward if using bidirectional layers)
    # and return them as a single (batch_size, self.get_output_dim()) tensor.

    # now of shape: (batch_size, num_layers * num_directions, hidden_size).
    unsorted_state = state.transpose(0, 1).index_select(0, restoration_indices)

    # Extract the last hidden vector, including both forward and backward states
    # if the cell is bidirectional. Then reshape by concatenation (in the case
    # we have bidirectional states) or just squash the 1st dimension in the non-
    # bidirectional case. Return tensor has shape (batch_size, hidden_size * num_directions).
    try:
        last_state_index = 2 if self._module.bidirectional else 1
    except AttributeError:
        last_state_index = 1
    last_layer_state = unsorted_state[:, -last_state_index:, :]
    return last_layer_state.contiguous().view([-1, self.get_output_dim()])
def _manipulate_mask(
    self, mask: torch.BoolTensor, student_scores: torch.Tensor, batch: Batch
) -> torch.BoolTensor:
    """
    Add one extra (masked-out) token to the mask, for compatibility with BART.

    Only necessary when examining decoder outputs directly.
    """
    if student_scores.size(1) == batch.label_vec.size(1) + 1:
        mask = torch.cat([mask.new_zeros([mask.size(0), 1]), mask], dim=1)
    return mask
def _get_target_token_embeddings(
    self, token_embeddings: torch.Tensor, mask: torch.BoolTensor, direction: int
) -> torch.Tensor:
    # Need to shift the mask in the correct direction
    zero_col = token_embeddings.new_zeros(mask.size(0), 1).to(dtype=torch.bool)
    if direction == 0:
        # forward direction, get token to right
        shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
    else:
        shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
    return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(
        -1, self._forward_dim
    )
def batch_grad(
    func: Callable,
    inputs: FloatTensor,
    idx: Union[int, Tuple[int], List] = None,
    mask: BoolTensor = None,
) -> FloatTensor:
    """Compute gradients for a batch of inputs

    Args:
        func (Callable): The function to differentiate; it is called on `inputs` and
            must return a tensor whose leading dimensions cover the batch.
        inputs (FloatTensor): The first dimension corresponds to the different instances.
        idx (Union[int, Tuple[int], List]): The index from the second dimension to the last.
            If a list is given, then the gradient of the sum of function values of these
            indices is computed for each instance.
        mask (BoolTensor): A boolean mask selecting which entries of `func(inputs)`
            contribute to the summed scalar that is differentiated; the output is
            reshaped as `out.view(-1, *mask.size())` before selection.

    Returns:
        FloatTensor: The gradient for each input instance.
    """
    assert torch.is_tensor(inputs)
    assert (idx is None) != (
        mask is None
    ), "Either idx or mask (and only one of them) has to be provided."

    inputs.requires_grad_()
    out = func(inputs)

    if idx is not None:
        if not isinstance(idx, list):
            idx = [idx]
        indices = []
        for i in range(inputs.size(0)):
            for j in idx:
                j = (j, ) if isinstance(j, int) else j
                indices.append((i, ) + j)
        t = out[list(zip(*indices))].sum(-1)
    else:
        # [M, B, ...]
        out = out.view(-1, *mask.size())
        t = out.masked_select(mask).sum()

    gradients = torch.autograd.grad(t, inputs)[0]
    return gradients
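# A minimal usage sketch for `batch_grad`, exercising the `mask` branch with a toy
# per-instance function. Everything here is illustrative; in practice `func` would
# be a model forward pass.
import torch

def toy_func(x):
    # Two outputs per instance: shape [batch, 2]
    return torch.stack([x.pow(2).sum(-1), x.sum(-1)], dim=-1)

x = torch.randn(4, 3)
# Select only the first output of every instance.
output_mask = torch.tensor([[True, False]] * 4)
grads = batch_grad(toy_func, x, mask=output_mask)
# `grads` has the same shape as `x`; since the selected output is sum(x_i ** 2),
# each row i equals 2 * x[i].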
def _greedy_decode(
    arc_scores: torch.Tensor, arc_tag_logits: torch.Tensor, mask: torch.BoolTensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decodes the head and head tag predictions by decoding the unlabeled arcs
    independently for each word and then again, predicting the head tags of
    these greedily chosen arcs independently.

    # Parameters

    arc_scores : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to
        generate a distribution over attachments of a given word to all other words.
    arc_tag_logits : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length, sequence_length, num_tags) used
        to generate a distribution over tags for each arc.
    mask : `torch.BoolTensor`, required.
        A mask of shape (batch_size, sequence_length).

    # Returns

    arc_probs : `torch.Tensor`
        A tensor of shape (batch_size, sequence_length, sequence_length) representing the
        probability of an arc being present for this edge.
    arc_tag_probs : `torch.Tensor`
        A tensor of shape (batch_size, sequence_length, sequence_length, num_tags)
        representing the distribution over edge tags for a given edge.
    """
    # Mask the diagonal, because we don't want self edges.
    inf_diagonal_mask = torch.diag(arc_scores.new(mask.size(1)).fill_(-numpy.inf))
    arc_scores = arc_scores + inf_diagonal_mask
    # shape (batch_size, sequence_length, sequence_length, num_tags)
    arc_tag_logits = arc_tag_logits + inf_diagonal_mask.unsqueeze(0).unsqueeze(-1)
    # Mask padded tokens, because we only want to consider actual word -> word edges.
    minus_mask = ~mask.unsqueeze(2)
    arc_scores.masked_fill_(minus_mask, -numpy.inf)
    arc_tag_logits.masked_fill_(minus_mask.unsqueeze(-1), -numpy.inf)
    # shape (batch_size, sequence_length, sequence_length)
    arc_probs = arc_scores.sigmoid()
    # shape (batch_size, sequence_length, sequence_length, num_tags)
    arc_tag_probs = torch.nn.functional.softmax(arc_tag_logits, dim=-1)
    return arc_probs, arc_tag_probs
def get_attention_masks(self, mask: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns 2 masks of shape (batch_size, timesteps, timesteps) representing
    1) non-padded elements, and
    2) elements of the sequence which are permitted to be involved in attention
       at a given timestep.
    """
    device = mask.device
    # Forward case:
    timesteps = mask.size(1)
    # Shape (1, timesteps, timesteps)
    subsequent = subsequent_mask(timesteps, device)
    # Broadcasted logical and - we want zero elements where either we have
    # padding from the mask, or we aren't allowed to use the timesteps.
    # Shape (batch_size, timesteps, timesteps)
    forward_mask = mask.unsqueeze(-1) & subsequent
    # Backward case - exactly the same, but transposed.
    backward_mask = forward_mask.transpose(1, 2)
    return forward_mask, backward_mask
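# The method above assumes a `subsequent_mask(timesteps, device)` helper that
# returns a (1, timesteps, timesteps) boolean mask. A compatible minimal sketch
# (not necessarily the implementation used in the surrounding codebase) is a
# lower-triangular mask: position i may attend to itself and earlier positions.
import torch

def subsequent_mask(size: int, device="cpu") -> torch.BoolTensor:
    # True where query position i is allowed to attend to key position j (j <= i).
    return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device)).unsqueeze(0)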
def forward(  # type: ignore
    self,
    tokens: TextFieldTensors,
    mask_positions: torch.BoolTensor,
    target_ids: TextFieldTensors = None,
) -> Dict[str, torch.Tensor]:
    """
    # Parameters

    tokens : `TextFieldTensors`
        The output of `TextField.as_tensor()` for a batch of sentences.
    mask_positions : `torch.LongTensor`
        The positions in `tokens` that correspond to [MASK] tokens that we should try to fill
        in. Shape should be (batch_size, num_masks).
    target_ids : `TextFieldTensors`
        This is a list of token ids that correspond to the mask positions we're trying to fill.
        It is the output of a `TextField`, purely for convenience, so we can handle wordpiece
        tokenizers and such without having to do crazy things in the dataset reader. We assume
        that there is exactly one entry in the dictionary, and that it has a shape identical to
        `mask_positions` - one target token per mask position.
    """
    targets = None
    if target_ids is not None:
        targets = util.get_token_ids_from_text_field_tensors(target_ids)
    mask_positions = mask_positions.squeeze(-1)
    batch_size, num_masks = mask_positions.size()
    if targets is not None and targets.size() != mask_positions.size():
        raise ValueError(
            f"Number of targets ({targets.size()}) and number of masks "
            f"({mask_positions.size()}) are not equal"
        )

    # Shape: (batch_size, num_tokens, embedding_dim)
    embeddings = self._text_field_embedder(tokens)

    # Shape: (batch_size, num_tokens, encoding_dim)
    if self._contextualizer:
        mask = util.get_text_field_mask(tokens)
        contextual_embeddings = self._contextualizer(embeddings, mask)
    else:
        contextual_embeddings = embeddings

    # Does advanced indexing to get the embeddings of just the mask positions, which is what
    # we're trying to predict.
    batch_index = torch.arange(0, batch_size).long().unsqueeze(1)
    mask_embeddings = contextual_embeddings[batch_index, mask_positions]

    target_logits = self._language_model_head(self._dropout(mask_embeddings))

    vocab_size = target_logits.size(-1)
    probs = torch.nn.functional.softmax(target_logits, dim=-1)
    k = min(vocab_size, 5)  # min here largely because tests use small vocab
    top_probs, top_indices = probs.topk(k=k, dim=-1)

    output_dict = {"probabilities": top_probs, "top_indices": top_indices}

    output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(tokens)

    if targets is not None:
        target_logits = target_logits.view(batch_size * num_masks, vocab_size)
        targets = targets.view(batch_size * num_masks)
        loss = torch.nn.functional.cross_entropy(target_logits, targets)
        self._perplexity(loss)
        output_dict["loss"] = loss

    return output_dict
def _get_and_record_component_attention_loss(
    self,
    teacher_attention_matrices: List[Dict[str, torch.Tensor]],
    student_attention_matrices: List[Dict[str, torch.Tensor]],
    mask: torch.BoolTensor,
    tokens_per_example: torch.Tensor,
    num_tokens: torch.Tensor,
    mapped_layers: List[int],
    attn_type: str,
    metric_name: str,
) -> torch.Tensor:
    """
    Calculate the given attention loss and register it as the given metric name.
    """
    assert isinstance(self, TorchGeneratorAgent)  # Code relies on methods

    # Select the right attention matrices
    selected_student_attn_matrices = [
        layer_matrices[attn_type] for layer_matrices in student_attention_matrices
    ]
    selected_teacher_attn_matrices = [
        layer_matrices[attn_type] for layer_matrices in teacher_attention_matrices
    ]

    batch_size = mask.size(0)
    per_layer_losses = []
    per_layer_per_example_losses = []
    for student_layer_idx, teacher_layer_idx in enumerate(mapped_layers):
        raw_layer_loss = F.mse_loss(
            input=selected_student_attn_matrices[student_layer_idx],
            target=selected_teacher_attn_matrices[teacher_layer_idx],
            reduction='none',
        )
        # Prevent infs from appearing in the loss term. Especially important with fp16
        clamped_layer_loss = torch.clamp(raw_layer_loss, min=0, max=NEAR_INF_FP16)
        # [batch size, n heads, query length, key length]
        reshaped_layer_loss = clamped_layer_loss.view(
            batch_size, -1, clamped_layer_loss.size(-2), clamped_layer_loss.size(-1)
        )
        # Take the mean over the attention heads and the key length
        mean_layer_loss = reshaped_layer_loss.mean(dim=(1, 3))
        assert mean_layer_loss.shape == mask.shape
        masked_layer_loss = mean_layer_loss * mask
        layer_loss_per_example = masked_layer_loss.sum(dim=-1)  # Sum over token dim
        # Divide before summing over examples so that values don't get too large
        layer_loss = masked_layer_loss.div(num_tokens).sum()
        per_layer_losses.append(layer_loss)
        per_layer_per_example_losses.append(layer_loss_per_example)
    attn_loss = torch.stack(per_layer_losses).mean()
    attn_loss_per_example = torch.stack(per_layer_per_example_losses, dim=1).mean(dim=1)

    # Record metric
    self.record_local_metric(
        metric_name, AverageMetric.many(attn_loss_per_example, tokens_per_example)
    )

    return attn_loss
def forward(
    self,  # type: ignore
    token_ids: torch.LongTensor,
    type_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    pos_tags: torch.LongTensor,
    word_mask: torch.BoolTensor,
    parent_mask: torch.BoolTensor,
    parent_start_mask: torch.BoolTensor,
    parent_end_mask: torch.BoolTensor,
    child_mask: torch.BoolTensor = None,
    parent_idxs: torch.LongTensor = None,
    parent_tags: torch.LongTensor = None,
    parent_starts: torch.BoolTensor = None,
    parent_ends: torch.BoolTensor = None,
    child_idxs: torch.BoolTensor = None,
    child_starts: torch.BoolTensor = None,
    child_ends: torch.BoolTensor = None,
):
    """
    todo implement docstring
    Args:
        token_ids: [batch_size, num_word_pieces]
        type_ids: [batch_size, num_word_pieces]
        offsets: [batch_size, num_words, 2]
        wordpiece_mask: [batch_size, num_word_pieces]
        pos_tags: [batch_size, num_words]
        word_mask: [batch_size, num_words]
        parent_mask: [batch_size, num_words]
        parent_start_mask: [batch_size, num_words]
        parent_end_mask: [batch_size, num_words]
        child_mask: [batch_size, num_words]
        parent_idxs: [batch_size]
        parent_tags: [batch_size]
        parent_starts: [batch_size]
        parent_ends: [batch_size]
        child_idxs: [batch_size, num_words]
        child_starts: [batch_size, num_words]
        child_ends: [batch_size, num_words]
    Returns:
        parent_probs: [batch_size, num_words]
        parent_tag_probs: [batch_size, num_words, num_tags]
        parent_start_probs: [batch_size, num_words]
        parent_end_probs: [batch_size, num_words]
        child_probs: [batch_size, num_words]
        child_start_probs: [batch_size, num_words]
        child_end_probs: [batch_size, num_words]
        arc_loss (if parent_idxs is not None)
        tag_loss (if parent_idxs and parent_tags are not None)
        start_loss (if parent_starts is not None)
        end_loss (if parent_ends is not None)
        child_loss (if child_idxs is not None)
        child_start_loss (if child_starts is not None)
        child_end_loss (if child_ends is not None)
    """
    cls_embedding, embedded_text_input = self.get_word_embedding(
        token_ids=token_ids,
        offsets=offsets,
        wordpiece_mask=wordpiece_mask,
        type_ids=type_ids,
    )
    if self.pos_embedding is not None:
        embedded_pos_tags = self.pos_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        if self.fuse_layer is not None:
            embedded_text_input = self.fuse_layer(embedded_text_input)
    # todo compare normal dropout with InputVariationalDropout
    embedded_text_input = self._dropout(embedded_text_input)

    if self.additional_encoder is not None:
        if self.config.additional_layer_type == "transformer":
            # bert = self.bert if self.arch == "bert" else self.roberta
            extended_attention_mask = self.bert.get_extended_attention_mask(
                word_mask, word_mask.size(), word_mask.device
            )
            encoded_text = self.additional_encoder(
                hidden_states=embedded_text_input,
                attention_mask=extended_attention_mask,
            )[0]
        else:
            encoded_text = self.additional_encoder(
                inputs=embedded_text_input, mask=word_mask
            )
    else:
        encoded_text = embedded_text_input

    batch_size, seq_len, encoding_dim = encoded_text.size()

    # shape (batch_size, sequence_length, tag_classes)
    parent_tag_scores = self.parent_tag_feedforward(encoded_text)
    # shape (batch_size, sequence_length)
    parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)
    parent_start_scores = self.parent_start_feedforward(encoded_text).squeeze(-1)
    parent_end_scores = self.parent_end_feedforward(encoded_text).squeeze(-1)

    # mask out impossible positions
    minus_inf = -1e8
    parent_mask = torch.logical_and(parent_mask, word_mask)
    parent_scores = parent_scores + (~parent_mask).float() * minus_inf
    parent_start_mask = torch.logical_and(parent_start_mask, word_mask)
    parent_start_scores = parent_start_scores + (~parent_start_mask).float() * minus_inf
    parent_end_mask = torch.logical_and(parent_end_mask, word_mask)
    parent_end_scores = parent_end_scores + (~parent_end_mask).float() * minus_inf

    parent_probs = F.softmax(parent_scores, dim=-1)
    parent_start_probs = F.softmax(parent_start_scores, dim=-1)
    parent_end_probs = F.softmax(parent_end_scores, dim=-1)
    parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

    output = (parent_probs, parent_tag_probs, parent_start_probs, parent_end_probs)

    if self.config.predict_child:
        child_scores = self.child_feedforward(encoded_text).squeeze(-1)
        child_start_scores = self.child_start_feedforward(encoded_text).squeeze(-1)
        child_end_scores = self.child_end_feedforward(encoded_text).squeeze(-1)
        # todo add child mask - child should be inside the origin span
        if child_mask is None:
            child_mask = torch.ones_like(word_mask)
        else:
            child_mask = torch.logical_and(child_mask, word_mask)
        child_scores = child_scores + (~child_mask).float() * minus_inf
        child_start_scores = child_start_scores + (~child_mask).float() * minus_inf
        child_end_scores = child_end_scores + (~child_mask).float() * minus_inf
        child_probs = torch.sigmoid(child_scores)
        child_start_probs = torch.sigmoid(child_start_scores)
        child_end_probs = torch.sigmoid(child_end_scores)
        output = output + (child_probs, child_start_probs, child_end_probs)

    # add losses
    batch_range_vector = get_range_vector(batch_size, get_device_of(encoded_text))  # [bsz]
    if parent_idxs is not None:
        # [bsz, seq_len]
        parent_logits = F.log_softmax(parent_scores, dim=-1)
        parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
        parent_arc_nll = parent_arc_nll.mean()
        output = output + (parent_arc_nll, )

        if parent_tags is not None:
            parent_tag_nll = F.cross_entropy(
                parent_tag_scores[batch_range_vector, parent_idxs], parent_tags
            )
            output = output + (parent_tag_nll, )

    if parent_starts is not None:
        # [bsz, seq_len]
        parent_start_logits = F.log_softmax(parent_start_scores, dim=-1)
        parent_start_nll = -parent_start_logits[batch_range_vector, parent_starts].mean()
        output = output + (parent_start_nll, )

    if parent_ends is not None:
        # [bsz, seq_len]
        parent_end_logits = F.log_softmax(parent_end_scores, dim=-1)
        parent_end_nll = -parent_end_logits[batch_range_vector, parent_ends].mean()
        output = output + (parent_end_nll, )

    if self.config.predict_child:
        if child_idxs is not None:
            child_loss = F.binary_cross_entropy_with_logits(
                child_scores, child_idxs.float(), reduction="none"
            )
            child_loss = (child_loss * child_mask).sum() / (child_mask.sum() + 1e-8)
            output = output + (child_loss, )
        if child_starts is not None:
            child_start_loss = F.binary_cross_entropy_with_logits(
                child_start_scores, child_starts.float(), reduction="none"
            )
            child_start_loss = (child_start_loss * child_mask).sum() / (child_mask.sum() + 1e-8)
            output = output + (child_start_loss, )
        if child_ends is not None:
            child_end_loss = F.binary_cross_entropy_with_logits(
                child_end_scores, child_ends.float(), reduction="none"
            )
            child_end_loss = (child_end_loss * child_mask).sum() / (child_mask.sum() + 1e-8)
            output = output + (child_end_loss, )

    return output
def forward(
    self,  # type: ignore
    token_ids: torch.LongTensor,
    type_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    dep_idxs: torch.LongTensor,
    dep_tags: torch.LongTensor,
    pos_tags: torch.LongTensor,
    word_mask: torch.BoolTensor,
):
    embedded_text_input = self.get_word_embedding(
        token_ids=token_ids,
        offsets=offsets,
        wordpiece_mask=wordpiece_mask,
        type_ids=type_ids,
    )
    if self.pos_embedding is not None:
        embedded_pos_tags = self.pos_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        if self.fuse_layer is not None:
            embedded_text_input = self.fuse_layer(embedded_text_input)
    # todo compare normal dropout with InputVariationalDropout
    embedded_text_input = self._input_dropout(embedded_text_input)

    if self.additional_encoder is not None:
        if self.config.additional_layer_type == "transformer":
            extended_attention_mask = self.bert.get_extended_attention_mask(
                word_mask, word_mask.size(), word_mask.device
            )
            encoded_text = self.additional_encoder(
                hidden_states=embedded_text_input,
                attention_mask=extended_attention_mask,
            )[0]
        else:
            encoded_text = self.additional_encoder(
                inputs=embedded_text_input, mask=word_mask
            )
    else:
        encoded_text = embedded_text_input

    batch_size, _, encoding_dim = encoded_text.size()
    head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Concatenate the head sentinel onto the sentence representation.
    encoded_text = torch.cat([head_sentinel, encoded_text], 1)
    word_mask = torch.cat([word_mask.new_ones(batch_size, 1), word_mask], 1)
    dep_idxs = torch.cat([dep_idxs.new_zeros(batch_size, 1), dep_idxs], 1)
    dep_tags = torch.cat([dep_tags.new_zeros(batch_size, 1), dep_tags], 1)
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation)

    minus_inf = -1e8
    minus_mask = ~word_mask * minus_inf
    attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

    if self.training:
        predicted_heads, predicted_head_tags = self._greedy_decode(
            head_tag_representation, child_tag_representation, attended_arcs, word_mask
        )
    else:
        predicted_heads, predicted_head_tags = self._mst_decode(
            head_tag_representation, child_tag_representation, attended_arcs, word_mask
        )

    arc_nll, tag_nll = self._construct_loss(
        head_tag_representation=head_tag_representation,
        child_tag_representation=child_tag_representation,
        attended_arcs=attended_arcs,
        head_indices=dep_idxs,
        head_tags=dep_tags,
        mask=word_mask,
    )

    return predicted_heads, predicted_head_tags, arc_nll, tag_nll
def _unfold_long_sequences(
    self,
    embeddings: torch.FloatTensor,
    mask: torch.BoolTensor,
    batch_size: int,
    num_segment_concat_wordpieces: int,
) -> torch.FloatTensor:
    """
    We take 2D segments of a long sequence and flatten them out to get the whole sequence
    representation while removing unnecessary special tokens.

    [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
    -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]

    We truncate the start and end tokens for all segments, recombine the segments,
    and manually add back the start and end tokens.

    # Parameters

    embeddings: `torch.FloatTensor`
        Shape: [batch_size * num_segments, self._max_length, embedding_size].
    mask: `torch.BoolTensor`
        Shape: [batch_size * num_segments, self._max_length].
        The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask`
        in `forward()`.
    batch_size: `int`
    num_segment_concat_wordpieces: `int`
        The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
        the original `token_ids.size(1)`.

    # Returns

    embeddings: `torch.FloatTensor`
        Shape: [batch_size, self._num_wordpieces, embedding_size].
    """

    def lengths_to_mask(lengths, max_len, device):
        return torch.arange(max_len, device=device).expand(
            lengths.size(0), max_len
        ) < lengths.unsqueeze(1)

    device = embeddings.device
    num_segments = int(embeddings.size(0) / batch_size)
    embedding_size = embeddings.size(2)

    # We want to remove all segment-level special tokens but maintain sequence-level ones
    num_wordpieces = num_segment_concat_wordpieces - (num_segments - 1) * self._num_added_tokens

    embeddings = embeddings.reshape(
        batch_size, num_segments * self._max_length, embedding_size
    )
    mask = mask.reshape(batch_size, num_segments * self._max_length)
    # We assume that all 1s in the mask precede all 0s, and add an assert for that.
    # Open an issue on GitHub if this breaks for you.
    # Shape: (batch_size,)
    seq_lengths = mask.sum(-1)
    if not (lengths_to_mask(seq_lengths, mask.size(1), device) == mask).all():
        raise ValueError(
            "Long sequence splitting only supports masks with all 1s preceding all 0s."
        )
    # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op
    end_token_indices = (
        seq_lengths.unsqueeze(-1) - torch.arange(self._num_added_end_tokens, device=device) - 1
    )

    # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
    start_token_embeddings = embeddings[:, : self._num_added_start_tokens, :]
    # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
    end_token_embeddings = batched_index_select(embeddings, end_token_indices)

    embeddings = embeddings.reshape(batch_size, num_segments, self._max_length, embedding_size)
    embeddings = embeddings[
        :, :, self._num_added_start_tokens : -self._num_added_end_tokens, :
    ]  # truncate segment-level start/end tokens
    embeddings = embeddings.reshape(batch_size, -1, embedding_size)  # flatten

    # Now try to put end token embeddings back which is a little tricky.

    # The number of segments each sequence spans, excluding padding. Mimicking ceiling operation.
    # Shape: (batch_size,)
    num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length
    # The number of indices that end tokens should shift back.
    num_removed_non_end_tokens = (
        num_effective_segments * self._num_added_tokens - self._num_added_end_tokens
    )
    # Shape: (batch_size, self._num_added_end_tokens)
    end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
    assert (end_token_indices >= self._num_added_start_tokens).all()
    # Add space for end embeddings
    embeddings = torch.cat([embeddings, torch.zeros_like(end_token_embeddings)], 1)
    # Add end token embeddings back
    embeddings.scatter_(
        1, end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings), end_token_embeddings
    )

    # Now put back start tokens. We can do this before putting back end tokens, but then
    # we need to change `num_removed_non_end_tokens` a little.
    embeddings = torch.cat([start_token_embeddings, embeddings], 1)

    # Truncate to original length
    embeddings = embeddings[:, :num_wordpieces, :]
    return embeddings
def sort_and_run_forward(
    self,
    module: Callable[
        [PackedSequence, Optional[RnnState]],
        Tuple[Union[PackedSequence, torch.Tensor], RnnState],
    ],
    inputs: torch.Tensor,
    mask: torch.BoolTensor,
    hidden_state: Optional[RnnState] = None,
):
    """
    This function exists because Pytorch RNNs require that their inputs be sorted
    before being passed as input. As all of our Seq2xxxEncoders use this functionality,
    it is provided in a base class. This method can be called on any module which
    takes as input a `PackedSequence` and some `hidden_state`, which can either be a
    tuple of tensors or a tensor.

    As all of our Seq2xxxEncoders have different return types, we return `sorted`
    outputs from the module, which is called directly. Additionally, we return the
    indices into the batch dimension required to restore the tensor to its correct,
    unsorted order and the number of valid batch elements (i.e. the number of elements
    in the batch which are not completely masked). This un-sorting and re-padding
    of the module outputs is left to the subclasses because their outputs have different
    types and handling them smoothly here is difficult.

    # Parameters

    module : `Callable[RnnInputs, RnnOutputs]`
        A function to run on the inputs, where
        `RnnInputs: [PackedSequence, Optional[RnnState]]` and
        `RnnOutputs: Tuple[Union[PackedSequence, torch.Tensor], RnnState]`.
        In most cases, this is a `torch.nn.Module`.
    inputs : `torch.Tensor`, required.
        A tensor of shape `(batch_size, sequence_length, embedding_size)` representing
        the inputs to the Encoder.
    mask : `torch.BoolTensor`, required.
        A tensor of shape `(batch_size, sequence_length)`, representing masked and
        non-masked elements of the sequence for each element in the batch.
    hidden_state : `Optional[RnnState]`, (default = `None`).
        A single tensor of shape (num_layers, batch_size, hidden_size) representing the
        state of an RNN, or a tuple of tensors of shapes (num_layers, batch_size, hidden_size)
        and (num_layers, batch_size, memory_size), representing the hidden state and memory
        state of an LSTM-like RNN.

    # Returns

    module_output : `Union[torch.Tensor, PackedSequence]`.
        A Tensor or PackedSequence representing the output of the Pytorch Module.
        The batch size dimension will be equal to `num_valid`, as sequences of zero
        length are clipped off before the module is called, as Pytorch cannot handle
        zero length sequences.
    final_states : `Optional[RnnState]`
        A Tensor representing the hidden state of the Pytorch Module. This can either
        be a single tensor of shape (num_layers, num_valid, hidden_size), for
        instance in the case of a GRU, or a tuple of tensors, such as those required
        for an LSTM.
    restoration_indices : `torch.LongTensor`
        A tensor of shape `(batch_size,)`, describing the re-indexing required to transform
        the outputs back to their original batch order.
    """
    # In some circumstances you may have sequences of zero length. `pack_padded_sequence`
    # requires all sequence lengths to be > 0, so remove sequences of zero length before
    # calling self._module, then fill with zeros.

    # First count how many sequences are empty.
    batch_size = mask.size(0)
    num_valid = torch.sum(mask[:, 0]).int().item()

    sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
    (
        sorted_inputs,
        sorted_sequence_lengths,
        restoration_indices,
        sorting_indices,
    ) = sort_batch_by_length(inputs, sequence_lengths)

    # Now create a PackedSequence with only the non-empty, sorted sequences.
    packed_sequence_input = pack_padded_sequence(
        sorted_inputs[:num_valid, :, :],
        sorted_sequence_lengths[:num_valid].data.tolist(),
        batch_first=True,
    )
    # Prepare the initial states.
    if not self.stateful:
        if hidden_state is None:
            initial_states: Any = hidden_state
        elif isinstance(hidden_state, tuple):
            initial_states = [
                state.index_select(1, sorting_indices)[:, :num_valid, :].contiguous()
                for state in hidden_state
            ]
        else:
            initial_states = hidden_state.index_select(1, sorting_indices)[
                :, :num_valid, :
            ].contiguous()
    else:
        initial_states = self._get_initial_states(batch_size, num_valid, sorting_indices)

    # Actually call the module on the sorted PackedSequence.
    module_output, final_states = module(packed_sequence_input, initial_states)

    return module_output, final_states, restoration_indices
def forward(
    self,  # type: ignore
    token_ids: torch.LongTensor,
    type_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    pos_tags: torch.LongTensor,
    word_mask: torch.BoolTensor,
    subtree_spans: torch.LongTensor = None,
):
    """
    todo implement docstring
    Args:
        token_ids: [batch_size, num_word_pieces]
        type_ids: [batch_size, num_word_pieces]
        offsets: [batch_size, num_words, 2]
        wordpiece_mask: [batch_size, num_word_pieces]
        pos_tags: [batch_size, num_words]
        word_mask: [batch_size, num_words]
        subtree_spans: [batch_size, num_words, 2]
    Returns:
        span_start_logits: [batch_size, num_words, num_words]
        span_end_logits: [batch_size, num_words, num_words]
        span_loss: if subtree_spans is given.
    """
    # [bsz, seq_len, hidden]
    embedded_text_input = self.get_word_embedding(
        token_ids=token_ids,
        offsets=offsets,
        wordpiece_mask=wordpiece_mask,
        type_ids=type_ids,
    )
    if self.pos_embedding is not None:
        embedded_pos_tags = self.pos_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        if self.fuse_layer is not None:
            embedded_text_input = self.fuse_layer(embedded_text_input)
    # todo compare normal dropout with InputVariationalDropout
    embedded_text_input = self._dropout(embedded_text_input)

    if self.additional_encoder is not None:
        if self.config.additional_layer_type == "transformer":
            extended_attention_mask = self.bert.get_extended_attention_mask(
                word_mask, word_mask.size(), word_mask.device
            )
            encoded_text = self.additional_encoder(
                hidden_states=embedded_text_input,
                attention_mask=extended_attention_mask,
            )[0]
        else:
            encoded_text = self.additional_encoder(
                inputs=embedded_text_input, mask=word_mask
            )
    else:
        encoded_text = embedded_text_input

    batch_size, seq_len, encoding_dim = encoded_text.size()

    # [bsz, seq_len, dim]
    subtree_start_representation = self._dropout(self.subtree_start_feedforward(encoded_text))
    subtree_end_representation = self._dropout(self.subtree_end_feedforward(encoded_text))
    # [bsz, seq_len, seq_len]
    span_start_scores = self.subtree_start_attention(
        subtree_start_representation, subtree_start_representation
    )
    span_end_scores = self.subtree_end_attention(
        subtree_end_representation, subtree_end_representation
    )

    # the start of a subtree span must be at or before the word itself
    start_mask = word_mask.unsqueeze(-1) & (~torch.triu(span_start_scores.bool(), 1))
    # the end of a subtree span must be at or after the word itself
    end_mask = word_mask.unsqueeze(-1) & torch.triu(span_end_scores.bool())
    minus_inf = -1e8
    span_start_scores = span_start_scores + (~start_mask).float() * minus_inf
    span_end_scores = span_end_scores + (~end_mask).float() * minus_inf

    output = (
        F.log_softmax(span_start_scores, dim=-1),
        F.log_softmax(span_end_scores, dim=-1),
    )

    if subtree_spans is not None:
        start_loss = F.cross_entropy(
            span_start_scores.view(batch_size * seq_len, -1),
            subtree_spans[:, :, 0].view(-1),
        )
        end_loss = F.cross_entropy(
            span_end_scores.view(batch_size * seq_len, -1),
            subtree_spans[:, :, 1].view(-1),
        )
        span_loss = start_loss + end_loss
        output = output + (span_loss, )

    return output
def forward(
    self,  # type: ignore
    token_ids: torch.LongTensor,
    type_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    span_idx: torch.LongTensor,
    span_tag: torch.LongTensor,
    child_arcs: torch.LongTensor,
    child_tags: torch.LongTensor,
    pos_tags: torch.LongTensor,
    word_mask: torch.BoolTensor,
    mrc_mask: torch.BoolTensor,
):
    """
    todo implement docstring
    Args:
        token_ids: [batch_size, num_word_pieces]
        type_ids: [batch_size, num_word_pieces]
        offsets: [batch_size, num_words, 2]
        wordpiece_mask: [batch_size, num_word_pieces]
        span_idx: [batch_size, 2]
        span_tag: [batch_size, 1]
        child_arcs: [batch_size, num_words]
        child_tags: [batch_size, num_words]
        pos_tags: [batch_size, num_words]
        word_mask: [batch_size, num_words]
        mrc_mask: [batch_size, num_words]
    Returns:
        parent_probs: [batch_size, num_words]
        parent_tag_probs: [batch_size, num_words, num_tags]
        arc_nll: [1]
        tag_nll: [1]
    """
    embedded_text_input = self.get_word_embedding(
        token_ids=token_ids,
        offsets=offsets,
        wordpiece_mask=wordpiece_mask,
        type_ids=type_ids,
    )
    if self.pos_embedding is not None:
        embedded_pos_tags = self.pos_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        if self.fuse_layer is not None:
            embedded_text_input = self.fuse_layer(embedded_text_input)
    # todo compare normal dropout with InputVariationalDropout
    embedded_text_input = self._dropout(embedded_text_input)

    if self.additional_encoder is not None:
        if self.config.additional_layer_type == "transformer":
            extended_attention_mask = self.bert.get_extended_attention_mask(
                word_mask, word_mask.size(), word_mask.device
            )
            encoded_text = self.additional_encoder(
                hidden_states=embedded_text_input,
                attention_mask=extended_attention_mask,
            )[0]
        else:
            encoded_text = self.additional_encoder(
                inputs=embedded_text_input, mask=word_mask
            )
    else:
        encoded_text = embedded_text_input

    batch_size, seq_len, encoding_dim = encoded_text.size()

    # shape (batch_size, sequence_length, tag_classes)
    parent_tag_scores = self.parent_tag_feedforward(encoded_text)
    # shape (batch_size, sequence_length)
    parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)
    # [bsz, seq_len, tag_classes]
    child_tag_scores = self.child_tag_feedforward(encoded_text)
    # [bsz, seq_len]
    child_scores = self.child_feedforward(encoded_text).squeeze(-1)

    # todo support cases that span_idx and span_tag are None
    # [bsz]
    batch_range_vector = get_range_vector(batch_size, get_device_of(encoded_text))
    # [bsz]
    gold_positions = span_idx[:, 0]

    # compute parent arc loss
    minus_inf = -1e8
    mrc_mask = torch.logical_and(mrc_mask, word_mask)
    parent_scores = parent_scores + (~mrc_mask).float() * minus_inf
    child_scores = child_scores + (~mrc_mask).float() * minus_inf

    # [bsz, seq_len]
    parent_logits = F.log_softmax(parent_scores, dim=-1)
    parent_arc_nll = -parent_logits[batch_range_vector, gold_positions].mean()

    # compute parent tag loss
    parent_tag_nll = F.cross_entropy(
        parent_tag_scores[batch_range_vector, gold_positions], span_tag
    )

    parent_probs = F.softmax(parent_scores, dim=-1)
    parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)
    child_probs = torch.sigmoid(child_scores)
    child_tag_probs = F.softmax(child_tag_scores, dim=-1)

    child_arc_loss = F.binary_cross_entropy_with_logits(
        child_scores, child_arcs.float(), reduction="none"
    )
    child_arc_loss = (child_arc_loss * mrc_mask.float()).sum() / mrc_mask.float().sum()
    child_tag_loss = F.cross_entropy(
        child_tag_scores.view(batch_size * seq_len, -1),
        child_tags.view(-1),
        reduction="none",
    )
    child_tag_loss = (child_tag_loss * child_arcs.float().view(-1)).sum() / (
        child_arcs.float().sum() + 1e-8
    )

    return (
        parent_probs,
        parent_tag_probs,
        child_probs,
        child_tag_probs,
        parent_arc_nll,
        parent_tag_nll,
        child_arc_loss,
        child_tag_loss,
    )
def forward(
    self, inputs: torch.Tensor, mask: torch.BoolTensor, hidden_state: torch.Tensor = None
) -> torch.Tensor:
    if self.stateful and mask is None:
        raise ValueError("Always pass a mask with stateful RNNs.")
    if self.stateful and hidden_state is not None:
        raise ValueError("Stateful RNNs provide their own initial hidden_state.")

    if mask is None:
        return self._module(inputs, hidden_state)[0]

    batch_size, total_sequence_length = mask.size()

    packed_sequence_output, final_states, restoration_indices = self.sort_and_run_forward(
        self._module, inputs, mask, hidden_state
    )

    unpacked_sequence_tensor, _ = pad_packed_sequence(packed_sequence_output, batch_first=True)

    num_valid = unpacked_sequence_tensor.size(0)
    # Some RNNs (GRUs) only return one state as a Tensor. Others (LSTMs) return two.
    # If one state, use a single element list to handle in a consistent manner below.
    if not isinstance(final_states, (list, tuple)) and self.stateful:
        final_states = [final_states]

    # Add back invalid rows.
    if num_valid < batch_size:
        _, length, output_dim = unpacked_sequence_tensor.size()
        zeros = unpacked_sequence_tensor.new_zeros(batch_size - num_valid, length, output_dim)
        unpacked_sequence_tensor = torch.cat([unpacked_sequence_tensor, zeros], 0)

        # The states also need to have invalid rows added back.
        if self.stateful:
            new_states = []
            for state in final_states:
                num_layers, _, state_dim = state.size()
                zeros = state.new_zeros(num_layers, batch_size - num_valid, state_dim)
                new_states.append(torch.cat([state, zeros], 1))
            final_states = new_states

    # It's possible to need to pass sequences which are padded to longer than the
    # max length of the sequence to a Seq2SeqEncoder. However, packing and unpacking
    # the sequences mean that the returned tensor won't include these dimensions, because
    # the RNN did not need to process them. We add them back on in the form of zeros here.
    sequence_length_difference = total_sequence_length - unpacked_sequence_tensor.size(1)
    if sequence_length_difference > 0:
        zeros = unpacked_sequence_tensor.new_zeros(
            batch_size, sequence_length_difference, unpacked_sequence_tensor.size(-1)
        )
        unpacked_sequence_tensor = torch.cat([unpacked_sequence_tensor, zeros], 1)

    if self.stateful:
        self._update_states(final_states, restoration_indices)

    # Restore the original indices and return the sequence.
    return unpacked_sequence_tensor.index_select(0, restoration_indices)
def forward(
    self,  # type: ignore
    token_ids: torch.LongTensor,
    type_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    pos_tags: torch.LongTensor,
    word_mask: torch.BoolTensor,
    mrc_mask: torch.BoolTensor,
    parent_idxs: torch.LongTensor = None,
    parent_tags: torch.LongTensor = None,
    # is_subtree: torch.BoolTensor = None
):
    """
    todo implement docstring
    Args:
        token_ids: [batch_size, num_word_pieces]
        type_ids: [batch_size, num_word_pieces]
        offsets: [batch_size, num_words, 2]
        wordpiece_mask: [batch_size, num_word_pieces]
        pos_tags: [batch_size, num_words]
        word_mask: [batch_size, num_words]
        mrc_mask: [batch_size, num_words]
        parent_idxs: [batch_size]
        parent_tags: [batch_size]
        # is_subtree: [batch_size]
    Returns:
        # is_subtree_probs: [batch_size]
        parent_probs: [batch_size, num_words]
        parent_tag_probs: [batch_size, num_words, num_tags]
        # subtree_loss (if is_subtree is not None)
        arc_loss (if parent_idxs is not None)
        tag_loss (if parent_idxs and parent_tags are not None)
    """
    cls_embedding, embedded_text_input = self.get_word_embedding(
        token_ids=token_ids,
        offsets=offsets,
        wordpiece_mask=wordpiece_mask,
        type_ids=type_ids,
    )
    if self.pos_embedding is not None:
        embedded_pos_tags = self.pos_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        if self.fuse_layer is not None:
            embedded_text_input = self.fuse_layer(embedded_text_input)
    # todo compare normal dropout with InputVariationalDropout
    embedded_text_input = self._dropout(embedded_text_input)
    cls_embedding = self._dropout(cls_embedding)
    # [bsz]
    # subtree_scores = self.is_subtree_feedforward(cls_embedding).squeeze(-1)

    if self.additional_encoder is not None:
        if self.config.additional_layer_type == "transformer":
            extended_attention_mask = self.bert.get_extended_attention_mask(
                word_mask, word_mask.size(), word_mask.device
            )
            encoded_text = self.additional_encoder(
                hidden_states=embedded_text_input,
                attention_mask=extended_attention_mask,
            )[0]
        else:
            encoded_text = self.additional_encoder(
                inputs=embedded_text_input, mask=word_mask
            )
    else:
        encoded_text = embedded_text_input

    batch_size, seq_len, encoding_dim = encoded_text.size()

    # shape (batch_size, sequence_length, tag_classes)
    parent_tag_scores = self.parent_tag_feedforward(encoded_text)
    # shape (batch_size, sequence_length)
    parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)

    # mask out impossible positions
    minus_inf = -1e8
    mrc_mask = torch.logical_and(mrc_mask, word_mask)
    parent_scores = parent_scores + (~mrc_mask).float() * minus_inf

    parent_probs = F.softmax(parent_scores, dim=-1)
    parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

    # output = (torch.sigmoid(subtree_scores), parent_probs, parent_tag_probs)  # todo check if log in dp evaluation
    output = (parent_probs, parent_tag_probs)  # todo check if log in dp evaluation

    # add losses
    # if is_subtree is not None:
    #     subtree_loss = F.binary_cross_entropy_with_logits(subtree_scores, is_subtree.float())
    #     output = output + (subtree_loss, )
    # else:
    is_subtree = torch.ones_like(parent_tags).bool()
    if parent_idxs is not None:
        sample_mask = is_subtree.float()
        # [bsz]
        batch_range_vector = get_range_vector(batch_size, get_device_of(encoded_text))
        # [bsz, seq_len]
        parent_logits = F.log_softmax(parent_scores, dim=-1)
        parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
        parent_arc_nll = (parent_arc_nll * sample_mask).sum() / (sample_mask.sum() + 1e-8)
        output = output + (parent_arc_nll, )

        if parent_tags is not None:
            parent_tag_nll = F.cross_entropy(
                parent_tag_scores[batch_range_vector, parent_idxs],
                parent_tags,
                reduction="none",
            )
            parent_tag_nll = (parent_tag_nll * sample_mask).sum() / (sample_mask.sum() + 1e-8)
            output = output + (parent_tag_nll, )

    return output