def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.BoolTensor = None,
        position_ids: torch.LongTensor = None,
    ):
        if attention_mask is None:
            attention_mask = input_ids != self.pad_token_id

        # Prepend always-visible positions for the persistent memory slots.
        persistent_mask = torch.ones(input_ids.size(0),
                                     self.persistent_mem_size,
                                     dtype=torch.bool,
                                     device=input_ids.device)
        attention_mask = torch.cat([persistent_mask, attention_mask], dim=1)

        # Expand (batch, total_len) -> (batch, num_heads, total_len, total_len),
        # keeping only pairs where both the query and the key position are valid.
        extended_attention_mask = attention_mask.view(
            input_ids.shape[0], 1, 1,
            attention_mask.shape[1]).repeat(1, self.num_attention_heads,
                                            attention_mask.shape[1], 1)
        extended_attention_mask = (extended_attention_mask
                                   & extended_attention_mask.permute(0, 1, 3, 2))

        embedding_output = self.embeddings(input_ids=input_ids,
                                           position_ids=position_ids)
        hidden_states = self.encoder(hidden_states=embedding_output,
                                     attention_mask=extended_attention_mask)

        return hidden_states
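
The mask expansion above is easy to check in isolation. A minimal standalone sketch (batch size, head count, and pad token id are invented for illustration) shows how a 1-D padding mask becomes the pairwise (seq, seq) mask per head:

import torch

pad_token_id, num_heads = 0, 2
input_ids = torch.tensor([[5, 7, 0], [3, 0, 0]])             # (batch=2, seq=3); 0 is padding
attention_mask = input_ids != pad_token_id                   # (batch, seq) bool
mask = attention_mask.view(2, 1, 1, 3).repeat(1, num_heads, 3, 1)
mask = mask & mask.permute(0, 1, 3, 2)                       # True only if query AND key are real tokens
print(mask[0, 0].int())
# tensor([[1, 1, 0],
#         [1, 1, 0],
#         [0, 0, 0]])
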
Example #2
    def reset_states(self, mask: torch.BoolTensor = None) -> None:
        """
        Resets the internal states of a stateful encoder.

        # Parameters

        mask : `torch.BoolTensor`, optional.
            A tensor of shape `(batch_size,)` indicating which states should
            be reset. If not provided, all states will be reset.
        """
        if mask is None:
            self._states = None
        else:
            # state has shape (num_layers, batch_size, hidden_size). We reshape
            # mask to have shape (1, batch_size, 1) so that operations
            # broadcast properly.
            mask_batch_size = mask.size(0)
            mask = mask.view(1, mask_batch_size, 1)
            new_states = []
            for old_state in self._states:
                old_state_batch_size = old_state.size(1)
                if old_state_batch_size != mask_batch_size:
                    raise ValueError(
                        f"Trying to reset states using mask with incorrect batch size. "
                        f"Expected batch size: {old_state_batch_size}. "
                        f"Provided batch size: {mask_batch_size}.")
                new_state = ~mask * old_state
                new_states.append(new_state.detach())
            self._states = tuple(new_states)
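
The reset itself is just a broadcast multiply against the inverted mask. A minimal standalone sketch, with invented layer/batch/hidden sizes, shows the masked batch entries being zeroed while the rest are left untouched:

import torch

num_layers, batch_size, hidden_size = 2, 3, 4
state = torch.randn(num_layers, batch_size, hidden_size)
reset = torch.tensor([True, False, True])            # reset batch entries 0 and 2
masked_state = ~reset.view(1, batch_size, 1) * state # broadcast: zero out masked entries
print(masked_state[:, 0].abs().sum(), masked_state[:, 1].equal(state[:, 1]))
# tensor(0.) True
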
    def forward(self,
                input_ids,
                past=None,
                mask: torch.BoolTensor = None,
                token_type_ids=None,
                position_ids=None):
        """
        mask: [batch_size, seq_length] is attention mask
        """
        # past length calculation and dealing with past
        if past is None:
            past_length = input_ids.shape[1]
            past = [None] * 12
        else:
            # count self
            past_length = past[0].shape[3] + input_ids.shape[1]

        if mask is None:
            # Default: every position (past and current) is visible.
            mask = torch.ones(input_ids.shape[0],
                              past_length,
                              dtype=torch.bool,
                              device=input_ids.device)

        # Expand (batch, length) -> (batch, num_heads, length, length), keep
        # positions where both query and key are visible, then apply the
        # lower-triangular (causal) constraint.
        mask = mask.view(input_ids.shape[0], 1, 1,
                         mask.shape[1]).repeat(1, self.num_attention_heads,
                                               mask.shape[1], 1)
        mask = mask & mask.permute(0, 1, 3, 2)
        mask = torch.tril(mask)

        # calculate embedding output
        embedding_output = self.embeddings(input_ids,
                                           position_ids=position_ids)

        # Transformer layer
        last_layer_output, presents = self.encoder(embedding_output,
                                                   mask=mask,
                                                   past=past)

        return last_layer_output, presents
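
The padding-plus-causal mask built above can be reproduced outside the model. A minimal standalone sketch with made-up lengths (two new tokens on top of three cached ones) shows the mask rows the current tokens would use:

import torch

num_heads, cur_len, past_len = 2, 2, 3
total_len = past_len + cur_len
mask = torch.ones(1, total_len, dtype=torch.bool)            # no padding in this toy batch
mask = mask.view(1, 1, 1, total_len).repeat(1, num_heads, total_len, 1)
mask = torch.tril(mask & mask.permute(0, 1, 3, 2))           # causal over past + current tokens
print(mask[0, 0, past_len:].int())                           # the rows used by the current tokens
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
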
Example #4
def masked_index_fill(
    target: torch.Tensor, indices: torch.LongTensor, mask: torch.BoolTensor, fill_value: int = 1
) -> torch.Tensor:
    """
    The given `indices` in `target` will be filled with `fill_value`, given a `mask`.

    # Parameters

    target : `torch.Tensor`, required.
        A 2 dimensional tensor of shape (batch_size, sequence_length).
        This is the tensor to be filled.
    indices : `torch.LongTensor`, required.
        A 2 dimensional tensor of shape (batch_size, num_indices).
        These are the indices that will be filled in the original tensor.
    mask : `torch.BoolTensor`, required.
        A 2 dimensional tensor of shape (batch_size, num_indices), mask.sum() == `nonzero_indices`.
    fill_value : `int`, optional (default = `1`)
        The value we fill the tensor with.

    # Returns

    filled_target : `torch.Tensor`
        A tensor with shape (batch_size, sequence_length) where `indices` are filled with `fill_value`.
    """
    mask = mask.bool()
    prev_shape = target.size()
    # Shape: (batch_size * num_indices)
    flattened_indices = flatten_and_batch_shift_indices(indices * mask, target.size(1))
    # Shape: (batch_size * num_indices,)
    mask = mask.view(-1)
    # Shape: (batch_size * sequence_length, 1)
    flattened_target = target.view(-1, 1)
    # Shape: (nonzero_indices, 1)
    unmasked_indices = flattened_indices[mask].unsqueeze(-1)

    flattened_target = flattened_target.scatter(0, unmasked_indices, fill_value)

    filled_target = flattened_target.reshape(prev_shape)

    return filled_target
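
A quick usage sketch for the helper above, assuming `flatten_and_batch_shift_indices` (the AllenNLP utility it depends on) is importable from `allennlp.nn.util`; the tensor values are invented for illustration:

import torch
from allennlp.nn.util import flatten_and_batch_shift_indices  # dependency of the snippet above

target = torch.zeros(2, 5, dtype=torch.long)
indices = torch.tensor([[1, 3, 0], [2, 4, 0]])                 # last index of each row is padding
mask = torch.tensor([[True, True, False], [True, True, False]])
print(masked_index_fill(target, indices, mask))
# tensor([[0, 1, 0, 1, 0],
#         [0, 0, 1, 0, 1]])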