Example #1
def masked_topk(
    input_: torch.FloatTensor,
    mask: torch.BoolTensor,
    k: Union[int, torch.LongTensor],
    dim: int = -1,
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]:
    if input_.size() != mask.size():
        raise ValueError("`input_` and `mask` must have the same shape.")
    if not -input_.dim() <= dim < input_.dim():
        raise ValueError("`dim` must be in `[-input_.dim(), input_.dim())`")
    dim = (dim + input_.dim()) % input_.dim()

    max_k = k if isinstance(k, int) else k.max()


    permutation = list(range(input_.dim()))
    permutation.pop(dim)
    permutation += [dim]

    reverse_permutation = list(range(input_.dim() - 1))
    reverse_permutation.insert(dim, -1)

    other_dims_size = list(input_.size())
    other_dims_size.pop(dim)
    permuted_size = other_dims_size + [max_k]  # for restoration

    if isinstance(k, int):
        k = k * torch.ones(*other_dims_size, dtype=torch.long, device=mask.device)
    else:
        if list(k.size()) != other_dims_size:
            raise ValueError(
                "`k` must have the same shape as `input_` with dimension `dim` removed."
            )

    num_items = input_.size(dim)
    input_ = input_.permute(*permutation).reshape(-1, num_items)
    mask = mask.permute(*permutation).reshape(-1, num_items)
    k = k.reshape(-1)

    input_ = replace_masked_values(input_, mask, min_value_of_dtype(input_.dtype))

    _, top_indices = input_.topk(max_k, 1)

    top_indices_mask = get_mask_from_sequence_lengths(k, max_k).bool()

    fill_value, _ = top_indices.max(dim=1, keepdim=True)
    top_indices = torch.where(top_indices_mask, top_indices, fill_value)

    top_indices, _ = top_indices.sort(1)

    sequence_mask = mask.gather(1, top_indices)
    top_mask = top_indices_mask & sequence_mask

    top_input = input_.gather(1, top_indices)

    return (
        top_input.reshape(*permuted_size).permute(*reverse_permutation),
        top_mask.reshape(*permuted_size).permute(*reverse_permutation),
        top_indices.reshape(*permuted_size).permute(*reverse_permutation),
    )
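A minimal usage sketch for the function above (assumes masked_topk and the AllenNLP-style helpers it calls, such as replace_masked_values, are in scope; values are toy data):

import torch

scores = torch.tensor([[1.0, 5.0, 3.0, 2.0],
                       [4.0, 0.5, 9.0, 7.0]])
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])

# top-2 unmasked values per row, returned in original index order
top_values, top_mask, top_indices = masked_topk(scores, mask, k=2)
# row 0 -> values [5.0, 3.0] at indices [1, 2]; row 1 -> [4.0, 0.5] at [0, 1]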
Example #2
    def reset_states(self, mask: torch.BoolTensor = None) -> None:
        """
        Resets the internal states of a stateful encoder.

        # Parameters

        mask : `torch.BoolTensor`, optional.
            A tensor of shape `(batch_size,)` indicating which states should
            be reset. If not provided, all states will be reset.
        """
        if mask is None:
            self._states = None
        else:
            # state has shape (num_layers, batch_size, hidden_size). We reshape
            # mask to have shape (1, batch_size, 1) so that operations
            # broadcast properly.
            mask_batch_size = mask.size(0)
            mask = mask.view(1, mask_batch_size, 1)
            new_states = []
            for old_state in self._states:
                old_state_batch_size = old_state.size(1)
                if old_state_batch_size != mask_batch_size:
                    raise ValueError(
                        f"Trying to reset states using mask with incorrect batch size. "
                        f"Expected batch size: {old_state_batch_size}. "
                        f"Provided batch size: {mask_batch_size}.")
                new_state = ~mask * old_state
                new_states.append(new_state.detach())
            self._states = tuple(new_states)
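A small sketch of the broadcast trick used in reset_states above (shapes are illustrative only):

import torch

num_layers, batch_size, hidden_size = 2, 3, 4
state = torch.randn(num_layers, batch_size, hidden_size)
reset = torch.tensor([True, False, True])            # which batch entries to reset

new_state = ~reset.view(1, batch_size, 1) * state    # zeros rows 0 and 2, keeps row 1
assert torch.all(new_state[:, 0] == 0) and torch.all(new_state[:, 1] == state[:, 1])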
Example #3
    def forward(self, inputs: torch.Tensor,
                mask: torch.BoolTensor) -> torch.Tensor:
        """
        # Parameters

        inputs : `torch.Tensor`, required.
            A Tensor of shape `(batch_size, sequence_length, hidden_size)`.
        mask : `torch.BoolTensor`, required.
            A binary mask of shape `(batch_size, sequence_length)` representing the
            non-padded elements in each sequence in the batch.

        # Returns

        A `torch.Tensor` of shape (num_layers, batch_size, sequence_length, hidden_size),
        where the num_layers dimension represents the LSTM output from that layer.
        """
        batch_size, total_sequence_length = mask.size()
        stacked_sequence_output, final_states, restoration_indices = self.sort_and_run_forward(
            self._lstm_forward, inputs, mask)

        num_layers, num_valid, returned_timesteps, encoder_dim = stacked_sequence_output.size()
        # Add back invalid rows which were removed in the call to sort_and_run_forward.
        if num_valid < batch_size:
            zeros = stacked_sequence_output.new_zeros(num_layers,
                                                      batch_size - num_valid,
                                                      returned_timesteps,
                                                      encoder_dim)
            stacked_sequence_output = torch.cat(
                [stacked_sequence_output, zeros], 1)

            # The states also need to have invalid rows added back.
            new_states = []
            for state in final_states:
                state_dim = state.size(-1)
                zeros = state.new_zeros(num_layers, batch_size - num_valid,
                                        state_dim)
                new_states.append(torch.cat([state, zeros], 1))
            final_states = new_states

        # It's possible to need to pass sequences which are padded to longer than the
        # max length of the sequence to a Seq2StackEncoder. However, packing and unpacking
        # the sequences mean that the returned tensor won't include these dimensions, because
        # the RNN did not need to process them. We add them back on in the form of zeros here.
        sequence_length_difference = total_sequence_length - returned_timesteps
        if sequence_length_difference > 0:
            zeros = stacked_sequence_output.new_zeros(
                num_layers,
                batch_size,
                sequence_length_difference,
                stacked_sequence_output[0].size(-1),
            )
            stacked_sequence_output = torch.cat(
                [stacked_sequence_output, zeros], 2)

        self._update_states(final_states, restoration_indices)

        # Restore the original indices and return the sequence.
        # Has shape (num_layers, batch_size, sequence_length, hidden_size)
        return stacked_sequence_output.index_select(1, restoration_indices)
Example #4
 def _manipulate_mask(self, mask: torch.BoolTensor,
                      student_scores: torch.Tensor,
                      batch: Batch) -> torch.BoolTensor:
     """
     Add one extra (masked-out) token to the mask, for compatibility with BART.
     """
     assert student_scores.size(1) == batch.label_vec.size(1) + 1
     mask = torch.cat([mask.new_zeros([mask.size(0), 1]), mask], dim=1)
     return mask
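A sketch of the mask manipulation above: one masked-out column is prepended so the mask lines up with BART-style scores that carry one extra leading token (toy values):

import torch

mask = torch.tensor([[True, True, False],
                     [True, False, False]])
padded_mask = torch.cat([mask.new_zeros([mask.size(0), 1]), mask], dim=1)
# padded_mask has shape (2, 4) and its first column is all False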
Example #5
    def _greedy_decode(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        mask: torch.BoolTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        mask : `torch.BoolTensor`, optional
            A mask of shape (batch_size, sequence_length) denoting the unpadded
            elements in the sequence.

        # Returns

        heads : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        attended_arcs = attended_arcs + torch.diag(
            attended_arcs.new(mask.size(1)).fill_(-numpy.inf)
        )
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = ~mask.unsqueeze(2)
            attended_arcs.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(
            head_tag_representation, child_tag_representation, heads
        )
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags
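A toy sketch of the arc masking performed above before the greedy argmax (it mirrors the masking code, not the full model):

import numpy
import torch

attended_arcs = torch.randn(1, 4, 4)                 # (batch, seq_len, seq_len)
mask = torch.tensor([[True, True, True, False]])

# no self-loops: put -inf on the diagonal
attended_arcs = attended_arcs + torch.diag(attended_arcs.new(4).fill_(-numpy.inf))
# mask padded positions
attended_arcs.masked_fill_(~mask.unsqueeze(2), -numpy.inf)
heads = attended_arcs.max(dim=2)[1]                  # greedy head index per word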
Example #6
    def forward(self,
                inputs: torch.Tensor,
                mask: torch.BoolTensor,
                hidden_state: torch.Tensor = None) -> torch.Tensor:

        if mask is None:
            # If a mask isn't passed, there is no padding in the batch of instances, so we can just
            # return the last sequence output as the state.  This doesn't work in the case of
            # variable length sequences, as the last state for each element of the batch won't be
            # at the end of the max sequence length, so we have to use the state of the RNN below.
            return self._module(inputs, hidden_state)[0][:, -1, :]

        batch_size = mask.size(0)

        (
            _,
            state,
            restoration_indices,
        ) = self.sort_and_run_forward(self._module, inputs, mask, hidden_state)

        # Deal with the fact the LSTM state is a tuple of (state, memory).
        if isinstance(state, tuple):
            state = state[0]

        num_layers_times_directions, num_valid, encoding_dim = state.size()
        # Add back invalid rows.
        if num_valid < batch_size:
            # batch size is the second dimension here, because pytorch
            # returns RNN state as a tensor of shape (num_layers * num_directions,
            # batch_size, hidden_size)
            zeros = state.new_zeros(num_layers_times_directions,
                                    batch_size - num_valid, encoding_dim)
            state = torch.cat([state, zeros], 1)

        # Restore the original indices and return the final state of the
        # top layer. Pytorch's recurrent layers return state in the form
        # (num_layers * num_directions, batch_size, hidden_size) regardless
        # of the 'batch_first' flag, so we transpose, extract the relevant
        # layer state (both forward and backward if using bidirectional layers)
        # and return them as a single (batch_size, self.get_output_dim()) tensor.

        # now of shape: (batch_size, num_layers * num_directions, hidden_size).
        unsorted_state = state.transpose(0, 1).index_select(
            0, restoration_indices)

        # Extract the last hidden vector, including both forward and backward states
        # if the cell is bidirectional. Then reshape by concatenation (in the case
        # we have bidirectional states) or just squash the 1st dimension in the non-
        # bidirectional case. Return tensor has shape (batch_size, hidden_size * num_directions).
        try:
            last_state_index = 2 if self._module.bidirectional else 1
        except AttributeError:
            last_state_index = 1
        last_layer_state = unsorted_state[:, -last_state_index:, :]
        return last_layer_state.contiguous().view([-1, self.get_output_dim()])
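A sketch of the final-state extraction above for a 2-layer bidirectional RNN (shapes only, no actual RNN is run):

import torch

num_layers, num_directions, batch_size, hidden_size = 2, 2, 3, 5
state = torch.randn(num_layers * num_directions, batch_size, hidden_size)

unsorted_state = state.transpose(0, 1)                        # (batch, layers * directions, hidden)
last_layer_state = unsorted_state[:, -num_directions:, :]     # top layer, both directions
output = last_layer_state.contiguous().view(batch_size, -1)   # (batch, hidden * 2)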
Example #7
    def _manipulate_mask(self, mask: torch.BoolTensor,
                         student_scores: torch.Tensor,
                         batch: Batch) -> torch.BoolTensor:
        """
        Add one extra (masked-out) token to the mask, for compatibility with BART.

        Only necessary when examining decoder outputs directly.
        """
        if student_scores.size(1) == batch.label_vec.size(1) + 1:
            mask = torch.cat([mask.new_zeros([mask.size(0), 1]), mask], dim=1)
        return mask
Example #8
 def _get_target_token_embeddings(self, token_embeddings: torch.Tensor,
                                  mask: torch.BoolTensor,
                                  direction: int) -> torch.Tensor:
     # Need to shift the mask in the correct direction
     zero_col = token_embeddings.new_zeros(mask.size(0),
                                           1).to(dtype=torch.bool)
     if direction == 0:
         # forward direction, get token to right
         shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
     else:
         shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
     return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(
         -1, self._forward_dim)
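A toy sketch of the mask shifting above (direction 0 shifts the mask one step to the right, the other direction shifts it left):

import torch

mask = torch.tensor([[True, True, True, False]])
zero_col = mask.new_zeros(mask.size(0), 1)
shifted_forward = torch.cat([zero_col, mask[:, 0:-1]], dim=1)   # [[False, True, True, True]]
shifted_backward = torch.cat([mask[:, 1:], zero_col], dim=1)    # [[True, True, False, False]]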
Example #9
def batch_grad(
    func: Callable,
    inputs: FloatTensor,
    idx: Union[int, Tuple[int], List] = None,
    mask: BoolTensor = None,
) -> FloatTensor:
    """Compute gradients for a batch of inputs

    Args:
        func (Callable):
        inputs (FloatTensor): The first dimension corresponds the different
          instances.
        idx (Union[int, Tuple[int], List]): The index from the second dimension
          to the last. If a list is given, then the gradient of the sum of
          function values of these indices is computed for each instance.
        mask (BoolTensor):

    Returns:
        FloatTensor: The gradient for each input instance.
    """

    assert torch.is_tensor(inputs)
    assert (idx is None) != (
        mask is
        None), "Either idx or mask (and only one of them) has to be provided."

    inputs.requires_grad_()
    out = func(inputs)

    if idx is not None:
        if not isinstance(idx, list):
            idx = [idx]

        indices = []
        for i in range(inputs.size(0)):
            for j in idx:
                j = (j, ) if isinstance(j, int) else j
                indices.append((i, ) + j)
        t = out[list(zip(*indices))].sum(-1)
    else:
        # [M, B, ...]
        out = out.view(-1, *mask.size())
        t = out.masked_select(mask).sum()

    gradients = torch.autograd.grad(t, inputs)[0]

    return gradients
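A toy usage sketch for batch_grad (the function above must be in scope; f and out_mask are made-up names for this example):

import torch

def f(x):
    # per-instance outputs of shape (batch, 2)
    return torch.stack([x.sum(-1), (x ** 2).sum(-1)], dim=-1)

inputs = torch.randn(4, 3)
out_mask = torch.tensor([True, False])        # keep only output component 0
grads = batch_grad(f, inputs, mask=out_mask)
# grads has the same shape as inputs; for this f every entry is 1.0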
Example #10
    def _greedy_decode(
            arc_scores: torch.Tensor, arc_tag_logits: torch.Tensor,
            mask: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently.

        # Parameters

        arc_scores : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        arc_tag_logits : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length, num_tags) used to
            generate a distribution over tags for each arc.
        mask : `torch.BoolTensor`, required.
            A mask of shape (batch_size, sequence_length).

        # Returns

        arc_probs : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length, sequence_length) representing the
            probability of an arc being present for this edge.
        arc_tag_probs : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length, sequence_length, num_tags)
            representing the distribution over edge tags for a given edge.
        """
        # Mask the diagonal, because we don't want self edges.
        inf_diagonal_mask = torch.diag(
            arc_scores.new(mask.size(1)).fill_(-numpy.inf))
        arc_scores = arc_scores + inf_diagonal_mask
        # shape (batch_size, sequence_length, sequence_length, num_tags)
        arc_tag_logits = arc_tag_logits + inf_diagonal_mask.unsqueeze(
            0).unsqueeze(-1)
        # Mask padded tokens, because we only want to consider actual word -> word edges.
        minus_mask = ~mask.unsqueeze(2)
        arc_scores.masked_fill_(minus_mask, -numpy.inf)
        arc_tag_logits.masked_fill_(minus_mask.unsqueeze(-1), -numpy.inf)
        # shape (batch_size, sequence_length, sequence_length)
        arc_probs = arc_scores.sigmoid()
        # shape (batch_size, sequence_length, sequence_length, num_tags)
        arc_tag_probs = torch.nn.functional.softmax(arc_tag_logits, dim=-1)
        return arc_probs, arc_tag_probs
Example #11
    def get_attention_masks(self, mask: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns 2 masks of shape (batch_size, timesteps, timesteps) representing
        1) non-padded elements, and
        2) elements of the sequence which are permitted to be involved in attention at a given timestep.
        """
        device = mask.device
        # Forward case:
        timesteps = mask.size(1)
        # Shape (1, timesteps, timesteps)
        subsequent = subsequent_mask(timesteps, device)
        # Broadcasted logical and - we want zero
        # elements where either we have padding from the mask,
        # or we aren't allowed to use the timesteps.
        # Shape (batch_size, timesteps, timesteps)
        forward_mask = mask.unsqueeze(-1) & subsequent
        # Backward case - exactly the same, but transposed.
        backward_mask = forward_mask.transpose(1, 2)

        return forward_mask, backward_mask
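A sketch of the forward/backward attention masks above; a hand-rolled lower-triangular matrix stands in for subsequent_mask (an assumption about its behaviour):

import torch

mask = torch.tensor([[True, True, True, False]])      # (batch, timesteps)
timesteps = mask.size(1)
subsequent = torch.tril(torch.ones(1, timesteps, timesteps)).bool()
forward_mask = mask.unsqueeze(-1) & subsequent         # (batch, timesteps, timesteps)
backward_mask = forward_mask.transpose(1, 2)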
Example #12
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        mask_positions: torch.BoolTensor,
        target_ids: TextFieldTensors = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            The output of `TextField.as_tensor()` for a batch of sentences.
        mask_positions : `torch.LongTensor`
            The positions in `tokens` that correspond to [MASK] tokens that we should try to fill
            in.  Shape should be (batch_size, num_masks).
        target_ids : `TextFieldTensors`
            This is a list of token ids that correspond to the mask positions we're trying to fill.
            It is the output of a `TextField`, purely for convenience, so we can handle wordpiece
            tokenizers and such without having to do crazy things in the dataset reader.  We assume
            that there is exactly one entry in the dictionary, and that it has a shape identical to
            `mask_positions` - one target token per mask position.
        """

        targets = None
        if target_ids is not None:
            targets = util.get_token_ids_from_text_field_tensors(target_ids)
        mask_positions = mask_positions.squeeze(-1)
        batch_size, num_masks = mask_positions.size()
        if targets is not None and targets.size() != mask_positions.size():
            raise ValueError(
                f"Number of targets ({targets.size()}) and number of masks "
                f"({mask_positions.size()}) are not equal")

        # Shape: (batch_size, num_tokens, embedding_dim)
        embeddings = self._text_field_embedder(tokens)

        # Shape: (batch_size, num_tokens, encoding_dim)
        if self._contextualizer:
            mask = util.get_text_field_mask(tokens)
            contextual_embeddings = self._contextualizer(embeddings, mask)
        else:
            contextual_embeddings = embeddings

        # Does advanced indexing to get the embeddings of just the mask positions, which is what
        # we're trying to predict.
        batch_index = torch.arange(0, batch_size).long().unsqueeze(1)
        mask_embeddings = contextual_embeddings[batch_index, mask_positions]

        target_logits = self._language_model_head(
            self._dropout(mask_embeddings))

        vocab_size = target_logits.size(-1)
        probs = torch.nn.functional.softmax(target_logits, dim=-1)
        k = min(vocab_size,
                5)  # min here largely because tests use small vocab
        top_probs, top_indices = probs.topk(k=k, dim=-1)

        output_dict = {"probabilities": top_probs, "top_indices": top_indices}

        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)

        if targets is not None:
            target_logits = target_logits.view(batch_size * num_masks,
                                               vocab_size)
            targets = targets.view(batch_size * num_masks)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
            self._perplexity(loss)
            output_dict["loss"] = loss

        return output_dict
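A sketch of the advanced indexing used above to pull out the embeddings at the mask positions (toy shapes):

import torch

batch_size, num_tokens, dim = 2, 5, 3
contextual_embeddings = torch.randn(batch_size, num_tokens, dim)
mask_positions = torch.tensor([[1, 3], [0, 4]])                # (batch, num_masks)
batch_index = torch.arange(0, batch_size).long().unsqueeze(1)  # (batch, 1)
mask_embeddings = contextual_embeddings[batch_index, mask_positions]
# shape: (batch_size, num_masks, dim)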
Example #13
    def _get_and_record_component_attention_loss(
        self,
        teacher_attention_matrices: List[Dict[str, torch.Tensor]],
        student_attention_matrices: List[Dict[str, torch.Tensor]],
        mask: torch.BoolTensor,
        tokens_per_example: torch.Tensor,
        num_tokens: torch.Tensor,
        mapped_layers: List[int],
        attn_type: str,
        metric_name: str,
    ) -> torch.Tensor:
        """
        Calculate the given attention loss and register it as the given metric name.
        """

        assert isinstance(self, TorchGeneratorAgent)
        # Code relies on methods of TorchGeneratorAgent

        # Select the right attention matrices
        selected_student_attn_matrices = [
            layer_matrices[attn_type]
            for layer_matrices in student_attention_matrices
        ]
        selected_teacher_attn_matrices = [
            layer_matrices[attn_type]
            for layer_matrices in teacher_attention_matrices
        ]

        batch_size = mask.size(0)
        per_layer_losses = []
        per_layer_per_example_losses = []
        for student_layer_idx, teacher_layer_idx in enumerate(mapped_layers):
            raw_layer_loss = F.mse_loss(
                input=selected_student_attn_matrices[student_layer_idx],
                target=selected_teacher_attn_matrices[teacher_layer_idx],
                reduction='none',
            )
            clamped_layer_loss = torch.clamp(raw_layer_loss,
                                             min=0,
                                             max=NEAR_INF_FP16)
            # Prevent infs from appearing in the loss term. Especially important with
            # fp16
            reshaped_layer_loss = clamped_layer_loss.view(
                batch_size, -1, clamped_layer_loss.size(-2),
                clamped_layer_loss.size(-1))
            # [batch size, n heads, query length, key length]
            mean_layer_loss = reshaped_layer_loss.mean(dim=(1, 3))
            # Take the mean over the attention heads and the key length
            assert mean_layer_loss.shape == mask.shape
            masked_layer_loss = mean_layer_loss * mask
            layer_loss_per_example = masked_layer_loss.sum(
                dim=-1)  # Sum over token dim
            layer_loss = masked_layer_loss.div(num_tokens).sum()
            # Divide before summing over examples so that values don't get too large
            per_layer_losses.append(layer_loss)
            per_layer_per_example_losses.append(layer_loss_per_example)
        attn_loss = torch.stack(per_layer_losses).mean()
        attn_loss_per_example = torch.stack(per_layer_per_example_losses,
                                            dim=1).mean(dim=1)

        # Record metric
        self.record_local_metric(
            metric_name,
            AverageMetric.many(attn_loss_per_example, tokens_per_example))

        return attn_loss
Example #14
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        parent_mask: torch.BoolTensor,
        parent_start_mask: torch.BoolTensor,
        parent_end_mask: torch.BoolTensor,
        child_mask: torch.BoolTensor = None,
        parent_idxs: torch.LongTensor = None,
        parent_tags: torch.LongTensor = None,
        parent_starts: torch.BoolTensor = None,
        parent_ends: torch.BoolTensor = None,
        child_idxs: torch.BoolTensor = None,
        child_starts: torch.BoolTensor = None,
        child_ends: torch.BoolTensor = None,
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            parent_mask: [batch_size, num_words]
            parent_start_mask: [batch_size, num_words]
            parent_end_mask: [batch_size, num_words]
            child_mask: [batch_size, num_words]
            parent_idxs: [batch_size]
            parent_tags: [batch_size]
            parent_starts: [batch_size]
            parent_ends: [batch_size]
            child_idxs: [batch_size, num_words]
            child_starts: [batch_size, num_words]
            child_ends: [batch_size, num_words]
        Returns:
            parent_probs: [batch_size, num_words]
            parent_tag_probs: [batch_size, num_words, num_tags]
            parent_start_probs: [batch_size, num_words]
            parent_end_probs: [batch_size, num_words]
            child_probs: [batch_size, num_words]
            child_start_probs: [batch_size, num_words]
            child_end_probs: [batch_size, num_words]
            arc_loss (if parent_idx is not None)
            tag_loss (if parent_idxs and parent_tags are not None)
            start_loss (if parent_starts is not None)
            end_loss (if parent_ends is not None)
            child_loss (if child_idxs is not None)
            child_start_loss (if child_starts is not None)
            child_end_loss (if child_ends is not None)
        """

        cls_embedding, embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                # bert = self.bert if self.arch == "bert" else self.roberta
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # shape (batch_size, sequence_length, tag_classes)
        parent_tag_scores = self.parent_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length)
        parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)
        parent_start_scores = self.parent_start_feedforward(
            encoded_text).squeeze(-1)
        parent_end_scores = self.parent_end_feedforward(encoded_text).squeeze(
            -1)

        # mask out impossible positions
        minus_inf = -1e8
        parent_mask = torch.logical_and(parent_mask, word_mask)
        parent_scores = parent_scores + (~parent_mask).float() * minus_inf
        parent_start_mask = torch.logical_and(parent_start_mask, word_mask)
        parent_start_scores = parent_start_scores + (
            ~parent_start_mask).float() * minus_inf
        parent_end_mask = torch.logical_and(parent_end_mask, word_mask)
        parent_end_scores = parent_end_scores + (
            ~parent_end_mask).float() * minus_inf

        parent_probs = F.softmax(parent_scores, dim=-1)
        parent_start_probs = F.softmax(parent_start_scores, dim=-1)
        parent_end_probs = F.softmax(parent_end_scores, dim=-1)
        parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

        output = (parent_probs, parent_tag_probs, parent_start_probs,
                  parent_end_probs)

        if self.config.predict_child:
            child_scores = self.child_feedforward(encoded_text).squeeze(-1)
            child_start_scores = self.child_start_feedforward(
                encoded_text).squeeze(-1)
            child_end_scores = self.child_end_feedforward(
                encoded_text).squeeze(-1)
            # todo add child mask - child should be inside the origin span
            if child_mask is None:
                child_mask = torch.ones_like(word_mask)
            else:
                child_mask = torch.logical_and(child_mask, word_mask)
            child_scores = child_scores + (~child_mask).float() * minus_inf
            child_start_scores = child_start_scores + (
                ~child_mask).float() * minus_inf
            child_end_scores = child_end_scores + (
                ~child_mask).float() * minus_inf
            child_probs = torch.sigmoid(child_scores)
            child_start_probs = torch.sigmoid(child_start_scores)
            child_end_probs = torch.sigmoid(child_end_scores)
            output = output + (child_probs, child_start_probs, child_end_probs)

        # add losses
        batch_range_vector = get_range_vector(
            batch_size, get_device_of(encoded_text))  # [bsz]
        if parent_idxs is not None:
            # [bsz, seq_len]
            parent_logits = F.log_softmax(parent_scores, dim=-1)
            parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
            parent_arc_nll = parent_arc_nll.mean()
            output = output + (parent_arc_nll, )

            if parent_tags is not None:
                parent_tag_nll = F.cross_entropy(
                    parent_tag_scores[batch_range_vector, parent_idxs],
                    parent_tags)
                output = output + (parent_tag_nll, )

        if parent_starts is not None:
            # [bsz, seq_len]
            parent_start_logits = F.log_softmax(parent_start_scores, dim=-1)
            parent_start_nll = -parent_start_logits[batch_range_vector,
                                                    parent_starts].mean()
            output = output + (parent_start_nll, )

        if parent_ends is not None:
            # [bsz, seq_len]
            parent_end_logits = F.log_softmax(parent_end_scores, dim=-1)
            parent_end_nll = -parent_end_logits[batch_range_vector,
                                                parent_ends].mean()
            output = output + (parent_end_nll, )

        if self.config.predict_child:
            if child_idxs is not None:
                child_loss = F.binary_cross_entropy_with_logits(
                    child_scores, child_idxs.float(), reduction="none")
                child_loss = (child_loss *
                              child_mask).sum() / (child_mask.sum() + 1e-8)
                output = output + (child_loss, )
            if child_starts is not None:
                child_start_loss = F.binary_cross_entropy_with_logits(
                    child_start_scores, child_starts.float(), reduction="none")
                child_start_loss = (child_start_loss * child_mask).sum() / (
                    child_mask.sum() + 1e-8)
                output = output + (child_start_loss, )
            if child_ends is not None:
                child_end_loss = F.binary_cross_entropy_with_logits(
                    child_end_scores, child_ends.float(), reduction="none")
                child_end_loss = (child_end_loss *
                                  child_mask).sum() / (child_mask.sum() + 1e-8)
                output = output + (child_end_loss, )

        return output
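A minimal sketch of the additive masking pattern used throughout the forward pass above: invalid positions receive a large negative score so softmax gives them near-zero probability:

import torch
import torch.nn.functional as F

scores = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[True, True, False, False]])
minus_inf = -1e8
masked_scores = scores + (~mask).float() * minus_inf
probs = F.softmax(masked_scores, dim=-1)   # the last two positions get ~0 probability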
Example #15
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        dep_idxs: torch.LongTensor,
        dep_tags: torch.LongTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
    ):

        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, _, encoding_dim = encoded_text.size()
        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        word_mask = torch.cat([word_mask.new_ones(batch_size, 1), word_mask],
                              1)
        dep_idxs = torch.cat([dep_idxs.new_zeros(batch_size, 1), dep_idxs], 1)
        dep_tags = torch.cat([dep_tags.new_zeros(batch_size, 1), dep_tags], 1)

        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = ~word_mask * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, word_mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, word_mask)

        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=dep_idxs,
            head_tags=dep_tags,
            mask=word_mask,
        )

        return predicted_heads, predicted_head_tags, arc_nll, tag_nll
Example #16
    def _unfold_long_sequences(
        self,
        embeddings: torch.FloatTensor,
        mask: torch.BoolTensor,
        batch_size: int,
        num_segment_concat_wordpieces: int,
    ) -> torch.FloatTensor:
        """
        We take 2D segments of a long sequence and flatten them out to get the whole sequence
        representation while removing unnecessary special tokens.

        [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
        -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]

        We truncate the start and end tokens for all segments, recombine the segments,
        and manually add back the start and end tokens.

        # Parameters

        embeddings: `torch.FloatTensor`
            Shape: [batch_size * num_segments, self._max_length, embedding_size].
        mask: `torch.BoolTensor`
            Shape: [batch_size * num_segments, self._max_length].
            The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask`
            in `forward()`.
        batch_size: `int`
        num_segment_concat_wordpieces: `int`
            The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
            the original `token_ids.size(1)`.

        # Returns:

        embeddings: `torch.FloatTensor`
            Shape: [batch_size, self._num_wordpieces, embedding_size].
        """
        def lengths_to_mask(lengths, max_len, device):
            return torch.arange(max_len, device=device).expand(
                lengths.size(0), max_len) < lengths.unsqueeze(1)

        device = embeddings.device
        num_segments = int(embeddings.size(0) / batch_size)
        embedding_size = embeddings.size(2)

        # We want to remove all segment-level special tokens but maintain sequence-level ones
        num_wordpieces = num_segment_concat_wordpieces - (
            num_segments - 1) * self._num_added_tokens

        embeddings = embeddings.reshape(batch_size,
                                        num_segments * self._max_length,
                                        embedding_size)
        mask = mask.reshape(batch_size, num_segments * self._max_length)
        # We assume that all 1s in the mask precede all 0s, and add an assert for that.
        # Open an issue on GitHub if this breaks for you.
        # Shape: (batch_size,)
        seq_lengths = mask.sum(-1)
        if not (lengths_to_mask(seq_lengths, mask.size(1), device)
                == mask).all():
            raise ValueError(
                "Long sequence splitting only supports masks with all 1s preceding all 0s."
            )
        # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op
        end_token_indices = (
            seq_lengths.unsqueeze(-1) -
            torch.arange(self._num_added_end_tokens, device=device) - 1)

        # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
        start_token_embeddings = embeddings[:, :self._num_added_start_tokens, :]
        # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
        end_token_embeddings = batched_index_select(embeddings,
                                                    end_token_indices)

        embeddings = embeddings.reshape(batch_size, num_segments,
                                        self._max_length, embedding_size)
        # truncate segment-level start/end tokens
        embeddings = embeddings[:, :, self._num_added_start_tokens:-self._num_added_end_tokens, :]
        embeddings = embeddings.reshape(batch_size, -1,
                                        embedding_size)  # flatten

        # Now try to put end token embeddings back which is a little tricky.

        # The number of segments each sequence spans, excluding padding. Mimicking ceiling operation.
        # Shape: (batch_size,)
        num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length
        # The number of indices that end tokens should shift back.
        num_removed_non_end_tokens = (
            num_effective_segments * self._num_added_tokens -
            self._num_added_end_tokens)
        # Shape: (batch_size, self._num_added_end_tokens)
        end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
        assert (end_token_indices >= self._num_added_start_tokens).all()
        # Add space for end embeddings
        embeddings = torch.cat(
            [embeddings, torch.zeros_like(end_token_embeddings)], 1)
        # Add end token embeddings back
        embeddings.scatter_(
            1,
            end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings),
            end_token_embeddings)

        # Now put back start tokens. We can do this before putting back end tokens, but then
        # we need to change `num_removed_non_end_tokens` a little.
        embeddings = torch.cat([start_token_embeddings, embeddings], 1)

        # Truncate to original length
        embeddings = embeddings[:, :num_wordpieces, :]
        return embeddings
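A standalone sketch of the lengths_to_mask helper defined inside _unfold_long_sequences above (toy values):

import torch

lengths = torch.tensor([3, 1])
max_len = 4
mask = torch.arange(max_len).expand(lengths.size(0), max_len) < lengths.unsqueeze(1)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])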
Example #17
    def sort_and_run_forward(
        self,
        module: Callable[[PackedSequence, Optional[RnnState]],
                         Tuple[Union[PackedSequence, torch.Tensor],
                               RnnState], ],
        inputs: torch.Tensor,
        mask: torch.BoolTensor,
        hidden_state: Optional[RnnState] = None,
    ):
        """
        This function exists because Pytorch RNNs require that their inputs be sorted
        before being passed as input. As all of our Seq2xxxEncoders use this functionality,
        it is provided in a base class. This method can be called on any module which
        takes as input a `PackedSequence` and some `hidden_state`, which can either be a
        tuple of tensors or a tensor.

        As all of our Seq2xxxEncoders have different return types, we return `sorted`
        outputs from the module, which is called directly. Additionally, we return the
        indices into the batch dimension required to restore the tensor to its correct,
        unsorted order and the number of valid batch elements (i.e. the number of elements
        in the batch which are not completely masked). This un-sorting and re-padding
        of the module outputs is left to the subclasses because their outputs have different
        types and handling them smoothly here is difficult.

        # Parameters

        module : `Callable[RnnInputs, RnnOutputs]`
            A function to run on the inputs, where
            `RnnInputs: [PackedSequence, Optional[RnnState]]` and
            `RnnOutputs: Tuple[Union[PackedSequence, torch.Tensor], RnnState]`.
            In most cases, this is a `torch.nn.Module`.
        inputs : `torch.Tensor`, required.
            A tensor of shape `(batch_size, sequence_length, embedding_size)` representing
            the inputs to the Encoder.
        mask : `torch.BoolTensor`, required.
            A tensor of shape `(batch_size, sequence_length)`, representing masked and
            non-masked elements of the sequence for each element in the batch.
        hidden_state : `Optional[RnnState]`, (default = `None`).
            A single tensor of shape (num_layers, batch_size, hidden_size) representing the
            state of an RNN, or a tuple of
            tensors of shapes (num_layers, batch_size, hidden_size) and
            (num_layers, batch_size, memory_size), representing the hidden state and memory
            state of an LSTM-like RNN.

        # Returns

        module_output : `Union[torch.Tensor, PackedSequence]`.
            A Tensor or PackedSequence representing the output of the Pytorch Module.
            The batch size dimension will be equal to `num_valid`, as sequences of zero
            length are clipped off before the module is called, as Pytorch cannot handle
            zero length sequences.
        final_states : `Optional[RnnState]`
            A Tensor representing the hidden state of the Pytorch Module. This can either
            be a single tensor of shape (num_layers, num_valid, hidden_size), for instance in
            the case of a GRU, or a tuple of tensors, such as those required for an LSTM.
        restoration_indices : `torch.LongTensor`
            A tensor of shape `(batch_size,)`, describing the re-indexing required to transform
            the outputs back to their original batch order.
        """
        # In some circumstances you may have sequences of zero length. `pack_padded_sequence`
        # requires all sequence lengths to be > 0, so remove sequences of zero length before
        # calling self._module, then fill with zeros.

        # First count how many sequences are empty.
        batch_size = mask.size(0)
        num_valid = torch.sum(mask[:, 0]).int().item()

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        (
            sorted_inputs,
            sorted_sequence_lengths,
            restoration_indices,
            sorting_indices,
        ) = sort_batch_by_length(inputs, sequence_lengths)

        # Now create a PackedSequence with only the non-empty, sorted sequences.
        packed_sequence_input = pack_padded_sequence(
            sorted_inputs[:num_valid, :, :],
            sorted_sequence_lengths[:num_valid].data.tolist(),
            batch_first=True,
        )
        # Prepare the initial states.
        if not self.stateful:
            if hidden_state is None:
                initial_states: Any = hidden_state
            elif isinstance(hidden_state, tuple):
                initial_states = [
                    state.index_select(
                        1, sorting_indices)[:, :num_valid, :].contiguous()
                    for state in hidden_state
                ]
            else:
                initial_states = hidden_state.index_select(
                    1, sorting_indices)[:, :num_valid, :].contiguous()

        else:
            initial_states = self._get_initial_states(batch_size, num_valid,
                                                      sorting_indices)

        # Actually call the module on the sorted PackedSequence.
        module_output, final_states = module(packed_sequence_input,
                                             initial_states)

        return module_output, final_states, restoration_indices
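A sketch of the sort-and-pack step above using plain PyTorch in place of the AllenNLP helpers (sort_batch_by_length and friends are assumed to behave like this):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

inputs = torch.randn(3, 5, 8)                       # (batch, seq_len, emb)
lengths = torch.tensor([2, 5, 3])

sorted_lengths, sorting_indices = lengths.sort(descending=True)
restoration_indices = sorting_indices.argsort()
packed = pack_padded_sequence(inputs[sorting_indices],
                              sorted_lengths.tolist(), batch_first=True)
# run the RNN on `packed`, then index_select(0, restoration_indices) on the
# unpacked output to restore the original batch order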
Example #18
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        subtree_spans: torch.LongTensor = None,
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            subtree_spans: [batch_size, num_words, 2]
        Returns:
            span_start_logits: [batch_size, num_words, num_words]
            span_end_logits: [batch_size, num_words, num_words]
            span_loss: if subtree_spans is given.

        """
        # [bsz, seq_len, hidden]
        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # [bsz, seq_len, dim]
        subtree_start_representation = self._dropout(
            self.subtree_start_feedforward(encoded_text))
        subtree_end_representation = self._dropout(
            self.subtree_end_feedforward(encoded_text))
        # [bsz, seq_len, seq_len]
        span_start_scores = self.subtree_start_attention(
            subtree_start_representation, subtree_start_representation)
        span_end_scores = self.subtree_end_attention(
            subtree_end_representation, subtree_end_representation)

        # start of word should be less than or equal to it
        start_mask = word_mask.unsqueeze(-1) & (
            ~torch.triu(span_start_scores.bool(), 1))
        # end of word should be greater than or equal to it.
        end_mask = word_mask.unsqueeze(-1) & torch.triu(span_end_scores.bool())

        minus_inf = -1e8
        span_start_scores = span_start_scores + (
            ~start_mask).float() * minus_inf
        span_end_scores = span_end_scores + (~end_mask).float() * minus_inf

        output = (F.log_softmax(span_start_scores,
                                dim=-1), F.log_softmax(span_end_scores,
                                                       dim=-1))
        if subtree_spans is not None:

            start_loss = F.cross_entropy(
                span_start_scores.view(batch_size * seq_len, -1),
                subtree_spans[:, :, 0].view(-1))
            end_loss = F.cross_entropy(
                span_end_scores.view(batch_size * seq_len, -1),
                subtree_spans[:, :, 1].view(-1))
            span_loss = start_loss + end_loss
            output = output + (span_loss, )

        return output
Example #19
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        span_idx: torch.LongTensor,
        span_tag: torch.LongTensor,
        child_arcs: torch.LongTensor,
        child_tags: torch.LongTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        mrc_mask: torch.BoolTensor,
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            span_idx: [batch_size, 2]
            span_tag: [batch_size, 1]
            child_arcs: [batch_size, num_words]
            child_tags: [batch_size, num_words]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            mrc_mask: [batch_size, num_words]
        Returns:
            parent_probs: [batch_size, num_words]
            parent_tag_probs: [batch_size, num_words]
            arc_nll: [1]
            tag_nll: [1]
        """

        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # shape (batch_size, sequence_length, tag_classes)
        parent_tag_scores = self.parent_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length)
        parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)

        # [bsz, seq_len, tag_classes]
        child_tag_scores = self.child_tag_feedforward(encoded_text)
        # [bsz, seq_len]
        child_scores = self.child_feedforward(encoded_text).squeeze(-1)

        # todo support cases that span_idx and span_tag are None
        # [bsz]
        batch_range_vector = get_range_vector(batch_size,
                                              get_device_of(encoded_text))
        # [bsz]
        gold_positions = span_idx[:, 0]

        # compute parent arc loss
        minus_inf = -1e8
        mrc_mask = torch.logical_and(mrc_mask, word_mask)
        parent_scores = parent_scores + (~mrc_mask).float() * minus_inf
        child_scores = child_scores + (~mrc_mask).float() * minus_inf

        # [bsz, seq_len]
        parent_logits = F.log_softmax(parent_scores, dim=-1)
        parent_arc_nll = -parent_logits[batch_range_vector,
                                        gold_positions].mean()

        # compute parent tag loss
        parent_tag_nll = F.cross_entropy(
            parent_tag_scores[batch_range_vector, gold_positions], span_tag)

        parent_probs = F.softmax(parent_scores, dim=-1)
        parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)
        child_probs = torch.sigmoid(child_scores)
        child_tag_probs = F.softmax(child_tag_scores, dim=-1)

        child_arc_loss = F.binary_cross_entropy_with_logits(child_scores,
                                                            child_arcs.float(),
                                                            reduction="none")
        child_arc_loss = (child_arc_loss *
                          mrc_mask.float()).sum() / mrc_mask.float().sum()
        child_tag_loss = F.cross_entropy(child_tag_scores.view(
            batch_size * seq_len, -1),
                                         child_tags.view(-1),
                                         reduction="none")
        child_tag_loss = (child_tag_loss * child_arcs.float().view(-1)
                          ).sum() / (child_arcs.float().sum() + 1e-8)

        return parent_probs, parent_tag_probs, child_probs, child_tag_probs, parent_arc_nll, parent_tag_nll, child_arc_loss, child_tag_loss
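A sketch of the masked binary cross-entropy pattern used for the child losses above (toy values):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 4)
targets = torch.tensor([[1., 0., 1., 0.],
                        [0., 1., 0., 0.]])
mask = torch.tensor([[1., 1., 1., 0.],
                     [1., 1., 0., 0.]])
loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
loss = (loss * mask).sum() / (mask.sum() + 1e-8)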
Example #20
    def forward(self,
                inputs: torch.Tensor,
                mask: torch.BoolTensor,
                hidden_state: torch.Tensor = None) -> torch.Tensor:

        if self.stateful and mask is None:
            raise ValueError("Always pass a mask with stateful RNNs.")
        if self.stateful and hidden_state is not None:
            raise ValueError(
                "Stateful RNNs provide their own initial hidden_state.")

        if mask is None:
            return self._module(inputs, hidden_state)[0]

        batch_size, total_sequence_length = mask.size()

        packed_sequence_output, final_states, restoration_indices = self.sort_and_run_forward(
            self._module, inputs, mask, hidden_state)

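        # unpacked shape: (num_valid, longest_sequence_in_batch, hidden_size)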
        unpacked_sequence_tensor, _ = pad_packed_sequence(
            packed_sequence_output, batch_first=True)

        num_valid = unpacked_sequence_tensor.size(0)
        # Some RNNs (GRUs) only return one state as a Tensor.  Others (LSTMs) return two.
        # If one state, use a single element list to handle in a consistent manner below.
        if not isinstance(final_states, (list, tuple)) and self.stateful:
            final_states = [final_states]

        # Add back invalid rows.
        if num_valid < batch_size:
            _, length, output_dim = unpacked_sequence_tensor.size()
            zeros = unpacked_sequence_tensor.new_zeros(batch_size - num_valid,
                                                       length, output_dim)
            unpacked_sequence_tensor = torch.cat(
                [unpacked_sequence_tensor, zeros], 0)

            # The states also need to have invalid rows added back.
            if self.stateful:
                new_states = []
                for state in final_states:
                    num_layers, _, state_dim = state.size()
                    zeros = state.new_zeros(num_layers, batch_size - num_valid,
                                            state_dim)
                    new_states.append(torch.cat([state, zeros], 1))
                final_states = new_states

        # It's possible to need to pass sequences which are padded to longer than the
        # max length of the sequence to a Seq2SeqEncoder. However, packing and unpacking
        # the sequences mean that the returned tensor won't include these dimensions, because
        # the RNN did not need to process them. We add them back on in the form of zeros here.
        sequence_length_difference = total_sequence_length - unpacked_sequence_tensor.size(
            1)
        if sequence_length_difference > 0:
            zeros = unpacked_sequence_tensor.new_zeros(
                batch_size, sequence_length_difference,
                unpacked_sequence_tensor.size(-1))
            unpacked_sequence_tensor = torch.cat(
                [unpacked_sequence_tensor, zeros], 1)

        if self.stateful:
            self._update_states(final_states, restoration_indices)

        # Restore the original indices and return the sequence.
        return unpacked_sequence_tensor.index_select(0, restoration_indices)
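The shape bookkeeping at the end of this wrapper (re-adding rows for empty sequences, padding the time dimension back to the mask length, and restoring the original batch order) can be checked in isolation. The snippet below is a toy sketch with made-up sizes and a made-up `restoration_indices`; it is not part of the wrapper itself.

import torch

batch_size, total_sequence_length, output_dim = 4, 10, 8
# pretend the RNN only saw 3 non-empty rows, with sequences up to length 7
unpacked = torch.randn(3, 7, output_dim)
restoration_indices = torch.tensor([2, 0, 3, 1])

# add back rows for the empty sequences that were dropped before packing
num_valid = unpacked.size(0)
if num_valid < batch_size:
    zeros = unpacked.new_zeros(batch_size - num_valid, unpacked.size(1), output_dim)
    unpacked = torch.cat([unpacked, zeros], 0)

# pad the time dimension back to the full mask length
length_difference = total_sequence_length - unpacked.size(1)
if length_difference > 0:
    zeros = unpacked.new_zeros(batch_size, length_difference, output_dim)
    unpacked = torch.cat([unpacked, zeros], 1)

# restore the original (unsorted) batch order
output = unpacked.index_select(0, restoration_indices)
assert output.shape == (batch_size, total_sequence_length, output_dim)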
Пример #21
0
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        mrc_mask: torch.BoolTensor,
        parent_idxs: torch.LongTensor = None,
        parent_tags: torch.LongTensor = None,
        # is_subtree: torch.BoolTensor = None
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            mrc_mask: [batch_size, num_words]
            parent_idxs: [batch_size]
            parent_tags: [batch_size]
            # is_subtree: [batch_size]
        Returns:
            # is_subtree_probs: [batch_size]
            parent_probs: [batch_size, num_words]
            parent_tag_probs: [batch_size, num_words, num_tags]
            # subtree_loss (if is_subtree is not None)
            arc_loss (if parent_idxs is not None)
            tag_loss (if parent_idxs and parent_tags are not None)
        """

        cls_embedding, embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)
        cls_embedding = self._dropout(cls_embedding)

        # [bsz]
        # subtree_scores = self.is_subtree_feedforward(cls_embedding).squeeze(-1)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # shape (batch_size, sequence_length, tag_classes)
        parent_tag_scores = self.parent_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length)
        parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)

        # mask out impossible positions
        minus_inf = -1e8
        mrc_mask = torch.logical_and(mrc_mask, word_mask)
        parent_scores = parent_scores + (~mrc_mask).float() * minus_inf

        parent_probs = F.softmax(parent_scores, dim=-1)
        parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

        # output = (torch.sigmoid(subtree_scores), parent_probs, parent_tag_probs)  # todo check if log in dp evaluation
        output = (parent_probs, parent_tag_probs)  # todo check if log in dp evaluation

        # add losses
        # if is_subtree is not None:
        #     subtree_loss = F.binary_cross_entropy_with_logits(subtree_scores, is_subtree.float())
        #     output = output + (subtree_loss, )
        # else:
        if parent_idxs is not None:
            # parent_tags may be None at inference time, so the
            # "treat every sample as a valid subtree" mask is built from
            # parent_idxs, which is guaranteed to exist inside this branch
            is_subtree = torch.ones_like(parent_idxs).bool()
            sample_mask = is_subtree.float()
            # [bsz]
            batch_range_vector = get_range_vector(batch_size,
                                                  get_device_of(encoded_text))
            # [bsz, seq_len]
            parent_logits = F.log_softmax(parent_scores, dim=-1)
            parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
            parent_arc_nll = (parent_arc_nll *
                              sample_mask).sum() / (sample_mask.sum() + 1e-8)
            output = output + (parent_arc_nll, )

            if parent_tags is not None:
                parent_tag_nll = F.cross_entropy(
                    parent_tag_scores[batch_range_vector, parent_idxs],
                    parent_tags,
                    reduction="none")
                parent_tag_nll = (parent_tag_nll * sample_mask).sum() / (
                    sample_mask.sum() + 1e-8)
                output = output + (parent_tag_nll, )

        return output
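Assuming the shapes documented in the docstring, a caller would typically decode a prediction by taking the argmax over `parent_probs` and then reading the tag distribution at the predicted parent position. The lines below are a hypothetical usage sketch with dummy tensors, not code from this model.

import torch

# hypothetical outputs with the shapes documented above
batch_size, num_words, num_tags = 2, 5, 3
parent_probs = torch.rand(batch_size, num_words)
parent_tag_probs = torch.rand(batch_size, num_words, num_tags)

# [batch_size] predicted head position for each query word
predicted_parent = parent_probs.argmax(dim=-1)
batch_range = torch.arange(batch_size)
# [batch_size] predicted dependency tag at the predicted head position
predicted_tag = parent_tag_probs[batch_range, predicted_parent].argmax(dim=-1)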