def get_head_tags(
    self,
    head_tag_representation: torch.Tensor,
    child_tag_representation: torch.Tensor,
    head_indices: torch.Tensor,
) -> torch.Tensor:
    """
    Decodes the head tags given the head and child tag representations
    and a tensor of head indices to compute tags for. Note that these are
    either gold or predicted heads, depending on whether this function is
    being called to compute the loss, or if it's being called during inference.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length). The indices of the heads
        for every word.

    Returns
    -------
    head_tag_logits : ``torch.Tensor``
        A tensor of shape (batch_size, sequence_length, num_head_tags),
        representing logits for predicting a distribution over tags
        for each arc.
    """
    batch_size = head_tag_representation.size(0)
    # shape (batch_size, 1)
    range_vector = get_range_vector(
        batch_size, get_device_of(head_tag_representation)
    ).unsqueeze(1)

    # This next statement is quite a complex piece of indexing, which you really need
    # to read the docs to understand. See here:
    # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
    # In effect, we are selecting the indices corresponding to the heads of each word
    # from the sequence length dimension for each element in the batch.

    # shape (batch_size, sequence_length, tag_representation_dim)
    selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
    selected_head_tag_representations = selected_head_tag_representations.contiguous()
    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self.tag_bilinear(
        selected_head_tag_representations, child_tag_representation
    )
    return head_tag_logits
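# A minimal, self-contained sketch of the advanced-indexing step used above, written
# with plain PyTorch and made-up shapes. It shows how pairing a (batch_size, 1) range
# vector with a (batch_size, sequence_length) index tensor selects the head
# representation for every word. The names and shapes below are illustrative only
# and are not part of the parser's API.
import torch

batch_size, sequence_length, tag_representation_dim = 2, 5, 3
head_tag_representation = torch.randn(batch_size, sequence_length, tag_representation_dim)
# One (gold or predicted) head index per word.
head_indices = torch.randint(0, sequence_length, (batch_size, sequence_length))

# shape (batch_size, 1); broadcasts against head_indices during indexing.
range_vector = torch.arange(batch_size).unsqueeze(1)
# shape (batch_size, sequence_length, tag_representation_dim): position (b, i)
# holds head_tag_representation[b, head_indices[b, i]].
selected = head_tag_representation[range_vector, head_indices]
assert selected.shape == (batch_size, sequence_length, tag_representation_dim)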
def forward(
    self,
    sequence_tensor: torch.FloatTensor,
    value_tensor: torch.FloatTensor,
    span_indices: torch.LongTensor,
    sequence_mask: torch.LongTensor = None,
    span_indices_mask: torch.LongTensor = None,
) -> torch.FloatTensor:
    # Both of shape (batch_size, num_spans, 1)
    span_starts, span_ends = span_indices.split(1, dim=-1)
    # Shift the end indices to be inclusive, then compute span widths.
    span_ends = span_ends - 1
    span_widths = span_ends - span_starts
    max_batch_span_width = span_widths.max().item() + 1

    # Shape: (batch_size, sequence_length, 1)
    global_attention_logits = self._global_attention(sequence_tensor)

    # Shape: (1, 1, max_batch_span_width)
    max_span_range_indices = utils.get_range_vector(
        max_batch_span_width, sequence_tensor.device
    ).view(1, 1, -1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    # 1 for offsets that fall inside the span, 0 for width padding.
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # Mask out indices which fall before the start of the sequence,
    # then clamp them to 0 so they are valid to gather with.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    span_indices = torch.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_span_indices = utils.flatten_and_batch_shift_indices(
        span_indices, sequence_tensor.size(1)
    )

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = utils.batched_index_select(value_tensor, span_indices, flat_span_indices)

    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_logits = utils.batched_index_select(
        global_attention_logits, span_indices, flat_span_indices
    ).squeeze(-1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_weights = utils.masked_softmax(span_attention_logits, span_mask)

    # Shape: (batch_size, num_spans, embedding_dim)
    attended_text_embeddings = utils.weighted_sum(span_embeddings, span_attention_weights)

    if span_indices_mask is not None:
        # Shape: (batch_size, num_spans, embedding_dim)
        return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()

    return attended_text_embeddings
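# A small, hedged sketch of the span-masking trick above, using plain PyTorch and toy
# values; the utils.* helpers are assumed to exist in this codebase and are not
# reproduced here, and the batch dimension is dropped for readability. Given inclusive
# span ends and per-span widths, it builds the index/mask pair that the
# batched_index_select calls consume.
import torch

# Two spans over a length-6 sequence: tokens [1, 3] and [4, 4] (inclusive ends).
span_starts = torch.tensor([[1], [4]])
span_ends = torch.tensor([[3], [4]])
span_widths = span_ends - span_starts                     # shape (num_spans, 1)
max_batch_span_width = span_widths.max().item() + 1       # 3

# shape (1, max_batch_span_width)
max_span_range_indices = torch.arange(max_batch_span_width).view(1, -1)
# 1 where the offset is inside the span, 0 where it is width padding.
span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends - max_span_range_indices
# Drop indices that run off the front of the sequence, then clamp them to 0.
span_mask = span_mask * (raw_span_indices >= 0).float()
span_indices = torch.relu(raw_span_indices.float()).long()

print(span_indices)  # tensor([[3, 2, 1], [4, 3, 2]])
print(span_mask)     # tensor([[1., 1., 1.], [1., 0., 0.]])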
def construct_loss(
    self,
    head_tag_representation: torch.Tensor,
    child_tag_representation: torch.Tensor,
    attended_arcs: torch.Tensor,
    head_indices: torch.Tensor,
    head_tags: torch.Tensor,
    mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the arc and tag loss for a sequence given gold head indices and tags.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    attended_arcs : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
        a distribution over attachments of a given word to all other words.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length). The indices of the heads
        for every word.
    head_tags : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length). The dependency labels of the heads
        for every word.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded elements
        in the sequence.

    Returns
    -------
    arc_nll : ``torch.Tensor``, required.
        The negative log likelihood from the arc loss.
    tag_nll : ``torch.Tensor``, required.
        The negative log likelihood from the arc tag loss.
    """
    float_mask = mask.float()
    batch_size, sequence_length, _ = attended_arcs.size()
    # shape (batch_size, 1)
    range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
    # shape (batch_size, sequence_length, sequence_length)
    normalised_arc_logits = (
        masked_log_softmax(attended_arcs, mask)
        * float_mask.unsqueeze(2)
        * float_mask.unsqueeze(1)
    )

    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self.get_head_tags(
        head_tag_representation, child_tag_representation, head_indices
    )
    normalised_head_tag_logits = masked_log_softmax(
        head_tag_logits, mask.unsqueeze(-1)
    ) * float_mask.unsqueeze(-1)

    # index matrix with shape (batch_size, sequence_length)
    timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
    child_index = (
        timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
    )
    # shape (batch_size, sequence_length)
    arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
    tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]

    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    arc_loss = arc_loss[:, 1:]
    tag_loss = tag_loss[:, 1:]

    # The number of valid positions is equal to the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic ROOT token.
    valid_positions = mask.sum() - batch_size

    arc_nll = -arc_loss.sum() / valid_positions.float()
    tag_nll = -tag_loss.sum() / valid_positions.float()
    return arc_nll, tag_nll
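# A hedged, self-contained sketch of the arc-loss bookkeeping above. It uses plain
# torch.log_softmax in place of masked_log_softmax and assumes no padding, so the mask
# handling is omitted. Tensor names mirror the function above, but the shapes and
# values are illustrative only.
import torch

batch_size, sequence_length = 2, 4
attended_arcs = torch.randn(batch_size, sequence_length, sequence_length)
# Gold head index for every word; position 0 is the symbolic ROOT token.
head_indices = torch.randint(0, sequence_length, (batch_size, sequence_length))

# shape (batch_size, sequence_length, sequence_length): log P(head | child).
normalised_arc_logits = torch.log_softmax(attended_arcs, dim=-1)

range_vector = torch.arange(batch_size).unsqueeze(1)                 # (batch_size, 1)
child_index = torch.arange(sequence_length).expand(batch_size, -1)   # (batch_size, sequence_length)

# Pick out the log-probability of each word's gold head, then drop the ROOT position.
arc_loss = normalised_arc_logits[range_vector, child_index, head_indices][:, 1:]

# With no padding, every sequence contributes sequence_length - 1 scored words.
valid_positions = batch_size * (sequence_length - 1)
arc_nll = -arc_loss.sum() / valid_positions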