Example #1
from typing import Tuple, Union

import torch

# Helper functions assumed to be available from allennlp.nn.util.
from allennlp.nn.util import (
    get_mask_from_sequence_lengths,
    min_value_of_dtype,
    replace_masked_values,
)


def masked_topk(
    input_: torch.FloatTensor,
    mask: torch.BoolTensor,
    k: Union[int, torch.LongTensor],
    dim: int = -1,
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]:
    """
    Returns the top `k` elements of `input_` along `dim`, considering only
    positions where `mask` is True.  The returned `(top_input, top_mask,
    top_indices)` each have the same shape as `input_`, except that dimension
    `dim` has size `max(k)`; `top_mask` marks which of those slots are valid.
    """
    if input_.size() != mask.size():
        raise ValueError("`input_` and `mask` must have the same shape.")
    if not -input_.dim() <= dim < input_.dim():
        raise ValueError("`dim` must be in `[-input_.dim(), input_.dim())`")
    # Normalize a negative `dim` to its non-negative equivalent.
    dim = (dim + input_.dim()) % input_.dim()

    # The largest k we will need to select anywhere.
    max_k = k if isinstance(k, int) else k.max()

    # Permutation that moves `dim` to the last position, so the rest of the
    # function can work on a 2D view of shape (everything else, num_items).
    permutation = list(range(input_.dim()))
    permutation.pop(dim)
    permutation += [dim]

    # Permutation that restores the original dimension order afterwards.
    reverse_permutation = list(range(input_.dim() - 1))
    reverse_permutation.insert(dim, -1)

    other_dims_size = list(input_.size())
    other_dims_size.pop(dim)
    permuted_size = other_dims_size + [max_k]  # output shape before un-permuting

    if isinstance(k, int):
        # Broadcast a scalar k to one value per slice along `dim`.
        k = k * torch.ones(*other_dims_size, dtype=torch.long, device=mask.device)
    else:
        if list(k.size()) != other_dims_size:
            raise ValueError(
                "`k` must have the same shape as `input_` with dimension `dim` removed."
            )

    # Flatten everything except `dim`, so each row is an independent top-k problem.
    num_items = input_.size(dim)
    input_ = input_.permute(*permutation).reshape(-1, num_items)
    mask = mask.permute(*permutation).reshape(-1, num_items)
    k = k.reshape(-1)

    # Push masked-out entries to the smallest representable value so they can
    # never appear among the top k.
    input_ = replace_masked_values(input_, mask, min_value_of_dtype(input_.dtype))

    # Take the max_k largest entries of every row, then keep only the first k[i]
    # of them in row i.
    _, top_indices = input_.topk(max_k, 1)

    top_indices_mask = get_mask_from_sequence_lengths(k, max_k).bool()

    # Slots beyond k[i] are filled with a duplicate of that row's largest selected
    # index; they are marked invalid below, so the duplication is harmless.
    fill_value, _ = top_indices.max(dim=1, keepdim=True)
    top_indices = torch.where(top_indices_mask, top_indices, fill_value)

    # Re-sort the selected indices so they appear in their original order along `dim`.
    top_indices, _ = top_indices.sort(1)

    # A slot is valid only if it is within the row's k and the gathered position
    # was not masked in the first place.
    sequence_mask = mask.gather(1, top_indices)
    top_mask = top_indices_mask & sequence_mask

    top_input = input_.gather(1, top_indices)

    return (
        top_input.reshape(*permuted_size).permute(*reverse_permutation),
        top_mask.reshape(*permuted_size).permute(*reverse_permutation),
        top_indices.reshape(*permuted_size).permute(*reverse_permutation),
    )
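
A minimal usage sketch (the tensors and variable names below are illustrative, not from the original source): select the top-2 unmasked scores per row along the last dimension.

scores = torch.tensor([[0.1, 0.9, 0.3, 0.7],
                       [0.8, 0.2, 0.5, 0.4]])
scores_mask = torch.tensor([[True, True, False, True],
                            [True, False, True, True]])
top_scores, top_mask, top_indices = masked_topk(scores, scores_mask, k=2)
# top_indices == [[1, 3], [0, 2]]: the masked entries (0.3 and 0.2) are never
# selected, and the chosen indices come back in their original order along `dim`.
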
    def forward(self,
                input_ids,
                past=None,
                mask: torch.BoolTensor = None,
                token_type_ids=None,
                position_ids=None):
        """
        mask: [batch_size, seq_length] is attention mask
        """
        # Determine the total attended length and set up the per-layer cache.
        if past is None:
            past_length = input_ids.shape[1]
            past = [None] * 12  # one empty cache slot per transformer layer
        else:
            # cached tokens plus the tokens in the current input
            past_length = past[0].shape[3] + input_ids.shape[1]

        if mask is None:
            # Default: attend to every position, including the cached past.
            mask = torch.ones(input_ids.shape[0],
                              past_length,
                              dtype=torch.bool,
                              device=input_ids.device)

        # Build the causal attention mask: broadcast the padding mask to
        # [batch, heads, length, length], intersect it with its transpose so that
        # padded tokens neither attend nor are attended to, then keep the lower
        # triangle so each position only sees itself and earlier positions.
        mask = mask.view(input_ids.shape[0], 1, 1,
                         mask.shape[1]).repeat(1, self.num_attention_heads,
                                               mask.shape[1], 1)
        mask = mask & mask.permute(0, 1, 3, 2)
        mask = torch.tril(mask)

        # Token and position embeddings (`token_type_ids` is accepted but not
        # passed to the embedding layer here).
        embedding_output = self.embeddings(input_ids,
                                           position_ids=position_ids)

        # Transformer layer
        last_layer_output, presents = self.encoder(embedding_output,
                                                   mask=mask,
                                                   past=past)

        return last_layer_output, presents
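
The mask logic in `forward` combines a padding mask with a causal (lower-triangular) constraint. A standalone sketch of the same construction, with made-up batch size, head count, and padding pattern:

batch, heads, length = 1, 2, 4
pad_mask = torch.tensor([[True, True, True, False]])  # last token is padding
causal = pad_mask.view(batch, 1, 1, length).repeat(1, heads, length, 1)
causal = causal & causal.permute(0, 1, 3, 2)  # padded tokens neither attend nor are attended to
causal = torch.tril(causal)                   # each position sees only itself and earlier positions
# causal[0, 0] is lower triangular with its last row and column set to False.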