Code example #1
File: biaffine_dep_module.py  Project: vrmpx/relogic
  def get_head_tags(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        head_indices: torch.Tensor) -> torch.Tensor:
    """
    Decodes the head tags given the head and child tag representations
    and a tensor of head indices to compute tags for. Note that these are
    either gold or predicted heads, depending on whether this function is
    being called to compute the loss, or if it's being called during inference.
    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length). The indices of the heads
        for every word.
    Returns
    -------
    head_tag_logits : ``torch.Tensor``
        A tensor of shape (batch_size, sequence_length, num_head_tags),
        representing logits for predicting a distribution over tags
        for each arc.
    """
    batch_size = head_tag_representation.size(0)
    # shape (batch_size, 1)
    range_vector = get_range_vector(
      batch_size, get_device_of(head_tag_representation)
    ).unsqueeze(1)

    # This next statement is quite a complex piece of indexing, which you really
    # need to read the docs to understand. See here:
    # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
    # In effect, we are selecting the indices corresponding to the heads of each word from the
    # sequence length dimension for each element in the batch.

    # shape (batch_size, sequence_length, tag_representation_dim)
    selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
    selected_head_tag_representations = selected_head_tag_representations.contiguous()
    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self.tag_bilinear(
      selected_head_tag_representations, child_tag_representation
    )
    return head_tag_logits
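
The key step above is the advanced-indexing line `head_tag_representation[range_vector, head_indices]`. The following standalone sketch (with made-up shapes and a plain `torch.arange` standing in for `get_range_vector`) illustrates how that indexing picks out each word's head representation:

import torch

# Illustrative shapes only; they are not taken from the project.
batch_size, sequence_length, tag_representation_dim = 2, 5, 3

head_tag_representation = torch.randn(batch_size, sequence_length, tag_representation_dim)
# head_indices[b, i] is the (hypothetical) position of word i's head in batch element b.
head_indices = torch.randint(0, sequence_length, (batch_size, sequence_length))

# shape (batch_size, 1); broadcasts against head_indices in the indexing below.
range_vector = torch.arange(batch_size).unsqueeze(1)

# shape (batch_size, sequence_length, tag_representation_dim), with
# selected[b, i] == head_tag_representation[b, head_indices[b, i]]
selected = head_tag_representation[range_vector, head_indices]

assert selected.shape == (batch_size, sequence_length, tag_representation_dim)
assert torch.equal(selected[0, 2], head_tag_representation[0, head_indices[0, 2]])

In the method itself, the bilinear tag scorer is then applied to these selected head representations and the original child representations to produce per-arc tag logits.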
Code example #2
    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            value_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # Both of shape (batch_size, num_spans, 1). The span ends are exclusive,
        # so subtract 1 to make them inclusive before computing the span widths.
        span_starts, span_ends = span_indices.split(1, dim=-1)
        span_ends = span_ends - 1
        span_widths = span_ends - span_starts
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = utils.get_range_vector(
            max_batch_span_width, sequence_tensor.device).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_mask = (max_span_range_indices <= span_widths).float()
        # Token positions inside each span, counted back from the span end.
        raw_span_indices = span_ends - max_span_range_indices
        # Mask out positions that fall before the start of the sequence and clamp
        # the remaining indices to be non-negative so they are safe to gather with.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.relu(raw_span_indices.float()).long()

        flat_span_indices = utils.flatten_and_batch_shift_indices(
            span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = utils.batched_index_select(value_tensor,
                                                     span_indices,
                                                     flat_span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = utils.batched_index_select(
            global_attention_logits, span_indices,
            flat_span_indices).squeeze(-1)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = utils.masked_softmax(span_attention_logits,
                                                      span_mask)

        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = utils.weighted_sum(span_embeddings,
                                                      span_attention_weights)

        if span_indices_mask is not None:
            return attended_text_embeddings * span_indices_mask.unsqueeze(
                -1).float()

        return attended_text_embeddings
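
The index and mask arithmetic in this forward pass is easiest to follow with concrete numbers. Below is a small standalone sketch (the batch dimension is dropped and the spans are made up) of how the code enumerates token positions inside each span, masks padding, and clamps out-of-range indices to zero:

import torch

# Two hypothetical spans with inclusive end positions (the -1 above has already
# converted exclusive ends to inclusive ones).
span_starts = torch.tensor([[1], [0]])   # (num_spans, 1)
span_ends = torch.tensor([[3], [1]])     # (num_spans, 1), inclusive
span_widths = span_ends - span_starts
max_batch_span_width = span_widths.max().item() + 1

# (1, max_batch_span_width): offsets 0..max_width-1, counted back from each span end.
max_span_range_indices = torch.arange(max_batch_span_width).view(1, -1)
span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends - max_span_range_indices
span_mask = span_mask * (raw_span_indices >= 0).float()
span_indices = torch.relu(raw_span_indices.float()).long()

print(span_indices)  # tensor([[3, 2, 1], [1, 0, 0]]) -- positions in each span, end first
print(span_mask)     # tensor([[1., 1., 1.], [1., 1., 0.]]) -- trailing 0 marks padding

In the real forward pass the clamped indices are flattened with `flatten_and_batch_shift_indices` and used to gather per-token embeddings and attention logits, and the mask zeroes out the padded positions inside `masked_softmax`.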
Code example #3
File: biaffine_dep_module.py  Project: vrmpx/relogic
  def construct_loss(
          self,
          head_tag_representation: torch.Tensor,
          child_tag_representation: torch.Tensor,
          attended_arcs: torch.Tensor,
          head_indices: torch.Tensor,
          head_tags: torch.Tensor,
          mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the arc and tag loss for a sequence given gold head indices and tags.
    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    attended_arcs : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
        a distribution over attachments of a given word to all other words.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The indices of the heads for every word.
    head_tags : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The dependency labels of the heads for every word.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded
        elements in the sequence.
    Returns
    -------
    arc_nll : ``torch.Tensor``
        The negative log likelihood from the arc loss.
    tag_nll : ``torch.Tensor``
        The negative log likelihood from the arc tag loss.
    """
    float_mask = mask.float()
    batch_size, sequence_length, _ = attended_arcs.size()
    # shape (batch_size, 1)
    range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
    # shape (batch_size, sequence_length, sequence_length)
    normalised_arc_logits = (
          masked_log_softmax(attended_arcs, mask)
          * float_mask.unsqueeze(2)
          * float_mask.unsqueeze(1)
    )

    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self.get_head_tags(
      head_tag_representation, child_tag_representation, head_indices
    )
    normalised_head_tag_logits = masked_log_softmax(
      head_tag_logits, mask.unsqueeze(-1)
    ) * float_mask.unsqueeze(-1)
    # shape (sequence_length,)
    timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
    # index matrix with shape (batch_size, sequence_length)
    child_index = (
      timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
    )
    # shape (batch_size, sequence_length)
    arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
    tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    arc_loss = arc_loss[:, 1:]
    tag_loss = tag_loss[:, 1:]

    # The number of valid positions is equal to the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic ROOT token.
    valid_positions = mask.sum() - batch_size

    arc_nll = -arc_loss.sum() / valid_positions.float()
    tag_nll = -tag_loss.sum() / valid_positions.float()
    return arc_nll, tag_nll
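
As a rough illustration of the loss indexing, the following standalone sketch uses a plain `log_softmax` in place of `masked_log_softmax` (i.e. it assumes no padding) and shows how the `[range_vector, child_index, head_indices]` lookup picks out the log-probability of each gold head before the ROOT position is dropped and the sum is normalised:

import torch
import torch.nn.functional as F

batch_size, sequence_length = 2, 4
attended_arcs = torch.randn(batch_size, sequence_length, sequence_length)
# Hypothetical gold heads; position 0 plays the role of the artificial ROOT token.
head_indices = torch.randint(0, sequence_length, (batch_size, sequence_length))

# Log-probability of every (child, candidate head) pair; no masking in this sketch.
normalised_arc_logits = F.log_softmax(attended_arcs, dim=-1)

range_vector = torch.arange(batch_size).unsqueeze(1)   # (batch_size, 1)
child_index = torch.arange(sequence_length).expand(batch_size, sequence_length)

# arc_loss[b, i] = log P(head of word i is head_indices[b, i]) in batch element b.
arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]

# Drop position 0 (ROOT has no meaningful head) and average over the rest;
# with no padding, every remaining position is valid.
arc_loss = arc_loss[:, 1:]
valid_positions = batch_size * (sequence_length - 1)
arc_nll = -arc_loss.sum() / valid_positions

The tag loss in the method above is computed the same way, with `normalised_head_tag_logits` indexed by `head_tags` instead of `head_indices`.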