Example #1
0
def scatter_mul(tensor: Tensor, index: Tensor) -> Tensor:
    """Multiply together the rows of ``tensor`` that share a bag in ``index``.

    The product is computed with two ``torch.embedding_bag`` sum reductions:
    the magnitude as ``exp(sum(log|x|))`` and the sign from a per-bag count
    of negative entries.

    Args:
        tensor: values to reduce; dim 0 is the scattered dimension and any
            trailing dims are reduced independently. Non-contiguous tensors
            are accepted.
        index: bag assignment of every row of ``tensor``. It is converted into
            ``embedding_bag`` indices/offsets by ``scatter_index_to_ptr``.

    Returns:
        Tensor of shape ``(num_bags, *tensor.shape[1:])``, one row per bag
        produced by ``embedding_bag``. An empty bag sums to 0 in log-space,
        so its product is 1.

    Note:
        NOTE(review): because the magnitude goes through ``log|x|``, the
        gradient w.r.t. an entry that is exactly zero is not finite (likely
        NaN). Confirm whether callers backpropagate through zeros.
    """
    indices, offsets = scatter_index_to_ptr(index=index, device=tensor.device)
    # ``reshape`` (instead of ``view``) also accepts non-contiguous inputs such
    # as permuted tensors; it still returns a view whenever that is possible.
    tensor_view = tensor.reshape((tensor.size()[0], -1))

    # Log-magnitude of each bag's product: sum of log|x|.
    ret, _, _, _ = torch.embedding_bag(
        weight=tensor_view.abs().log(),
        indices=indices, offsets=offsets, mode=0,
    )
    # 1 - sign(x) maps -1 -> 2, 0 -> 1, +1 -> 0, so each bag sums to
    # 2 * (#negatives) + (#zeros). Modulo 4 this is 0 for an even number of
    # negatives and 2 for an odd number. A zero entry already forces the
    # magnitude to exp(-inf) = 0, so its contribution to the sign is irrelevant.
    sgn, _, _, _ = torch.embedding_bag(
        weight=tensor_view.sign().neg().add(1.),
        indices=indices, offsets=offsets, mode=0,
    )
    sgn = (sgn % 4).neg().add(1.)  # 0 -> +1, 2 -> -1

    return (sgn.detach() * ret.exp()).view((ret.size()[0], *tensor.size()[1:]))
Example #2
0
def scatter_logsumexp(tensor: Tensor, index: Tensor) -> Tensor:
    """Compute ``log(sum(exp(x)))`` over the rows of ``tensor`` grouped by ``index``.

    The per-bag maximum ``m`` is subtracted before exponentiating and added
    back after the log, so large values do not overflow ``exp``. ``m`` is only
    a shift, so it is computed under ``torch.no_grad()``.

    Args:
        tensor: values to reduce; dim 0 is the scattered dimension and any
            trailing dims are reduced independently. Non-contiguous tensors
            are accepted.
        index: bag id of every row of ``tensor``. It is used to build the
            ``embedding_bag`` indices/offsets (via ``scatter_index_to_ptr``)
            and to gather each row's bag maximum (``m[index]``).

    Returns:
        Tensor of shape ``(num_bags, *tensor.shape[1:])``. A bag whose
        exponential sum is 0 (a bag with no rows, or one whose entries are all
        ``-inf``) yields ``0``, because a zero sum is replaced by 1 before the
        log.
    """
    indices, offsets = scatter_index_to_ptr(index=index, device=tensor.device)
    # ``reshape`` (instead of ``view``) also accepts non-contiguous inputs.
    tensor_view = tensor.reshape((tensor.size()[0], -1))

    with torch.no_grad():
        m, _, _, _ = torch.embedding_bag(
            weight=tensor_view,
            indices=indices, offsets=offsets, mode=2,
        )
        # An infinite maximum would make ``x - m`` evaluate ``inf - inf = nan``.
        # Shift by 0 instead, as ``torch.logsumexp`` does, so a bag containing
        # +inf correctly reduces to +inf.
        m = m.masked_fill(torch.isinf(m), 0.)

    z, _, _, _ = torch.embedding_bag(
        weight=(tensor_view - m[index]).exp(),
        indices=indices, offsets=offsets, mode=0,
    )
    ret = torch.masked_fill(z, z == 0, 1.).log() + m
    return ret.view((ret.size()[0], *tensor.size()[1:]))
Example #3
0
def scatter_min(tensor: Tensor, index: Tensor) -> Tensor:
    """Take the minimum over the rows of ``tensor`` that share a bag in ``index``.

    ``torch.embedding_bag`` offers sum, mean and max reductions but no min,
    so the minimum is computed as ``-max(-x)``.

    Args:
        tensor: values to reduce; dim 0 is the scattered dimension and any
            trailing dims are reduced independently. Non-contiguous tensors
            are accepted.
        index: bag assignment of every row of ``tensor``. It is converted into
            ``embedding_bag`` indices/offsets by ``scatter_index_to_ptr``.

    Returns:
        Tensor of shape ``(num_bags, *tensor.shape[1:])``, one row per bag.
    """
    indices, offsets = scatter_index_to_ptr(index=index, device=tensor.device)
    # ``neg()`` preserves the input's strides, so a non-contiguous input stays
    # non-contiguous. ``reshape`` handles that case where ``view`` would raise.
    ret, _, _, _ = torch.embedding_bag(
        weight=tensor.neg().reshape((tensor.size()[0], -1)),
        indices=indices, offsets=offsets, mode=2,
    )
    return ret.neg().view((ret.size()[0], *tensor.size()[1:]))
Example #4
0
def mlm_bag_catted_sequence(sequence: CattedSequence, index: CattedSequence, mode: int,
                            tokenizer: PreTrainedTokenizer, model: PreTrainedModel) -> CattedSequence:
    """Reduce masked-LM outputs into bags described by ``index``.

    ``index`` must be aligned with ``sequence`` (equal ``token_sizes``). The
    bag ``indices``/``offsets`` and the model input come from
    ``mlm_bag_catted_indices``. The tensor returned by ``mlm_padded_sequence``
    has its first two dims flattened and is then reduced per bag with
    ``torch.embedding_bag`` using ``mode`` (0 = sum, 1 = mean, 2 = max).

    Returns:
        A ``CattedSequence`` holding the reduced rows and the ``token_sizes``
        returned by ``mlm_bag_catted_indices``.
    """
    assert torch.equal(sequence.token_sizes, index.token_sizes), f'{sequence.token_sizes} != {index.token_sizes}'

    # NOTE(review): apart from the size check above, only the device of the
    # incoming ``sequence`` is used. The sequence actually fed to the model is
    # the one returned by mlm_bag_catted_indices. Confirm this is intended.
    target_device = sequence.data.device
    model_input, bag_indices, bag_offsets, bag_sizes = mlm_bag_catted_indices(
        index=index, tokenizer=tokenizer,
        device=target_device,
    )
    encoded = mlm_padded_sequence(sequence=model_input, tokenizer=tokenizer, model=model)

    # Merge the first two dims so every row can be addressed by a flat bag index.
    flat_rows = encoded.flatten(start_dim=0, end_dim=1)
    pooled, _, _, _ = torch.embedding_bag(
        weight=flat_rows,
        indices=bag_indices, offsets=bag_offsets, mode=mode,
    )
    return CattedSequence(data=pooled, token_sizes=bag_sizes)
Example #5
0
def _nn_functional_embedding_bag(input,
                                 weight,
                                 offsets=None,
                                 max_norm=None,
                                 norm_type=2,
                                 scale_grad_by_freq=False,
                                 mode='mean',
                                 sparse=False,
                                 per_sample_weights=None,
                                 include_last_offset=False):
    """NestedTensor variant of ``torch.nn.functional.embedding_bag``.

    This function does the following:

    * validates the arguments;
    * for a 2-D ``input``, derives ``offsets`` from the nested sizes;
    * maps ``mode`` to the integer code taken by ``torch.embedding_bag``;
    * returns the first output of that call.

    Args:
        input: index tensor. A 1-D input is rejected. For a 2-D input, bag
            ``i`` presumably starts after the leading sizes of nested entries
            ``0..i-1`` (TODO confirm against NestedTensor.nested_size).
        weight: embedding matrix. If ``weight`` is ``torch.long`` and
            ``input`` is floating point, the two are swapped with a
            deprecation warning (old argument order).
        offsets: must be None when ``input`` is 2-D.
        max_norm: passed to ``_not_impl_raise`` (presumably rejected unless
            None).
        norm_type: accepted for signature compatibility; not used here.
        scale_grad_by_freq: forwarded to ``torch.embedding_bag``; rejected
            for ``mode='max'``.
        mode: one of ``'sum'``, ``'mean'`` or ``'max'``.
        sparse: forwarded to ``torch.embedding_bag``; rejected for
            ``mode='max'``.
        per_sample_weights: must have ``input``'s size if given; also passed
            to ``_not_impl_raise``.
        include_last_offset: forwarded to ``torch.embedding_bag``.

    Returns:
        The reduced bags (first element of ``torch.embedding_bag``'s result).

    Raises:
        ValueError: mismatched ``per_sample_weights``, ``offsets`` given with
            a 2-D input, 1-D input, unknown ``mode``, or an option unsupported
            by ``mode='max'``.
        NotImplementedError: ``per_sample_weights`` with a mode other than
            ``'sum'``.
    """
    # Check for backward compatibility.
    # Used to be embedding_bag(weight, input, ...)
    # Now is     embedding_bag(input, weight, ...)
    if weight.dtype == torch.long and input.is_floating_point():
        warnings.warn(
            "Argument order of nn.functional.embedding_bag was changed. "
            "Usage `embedding_bag(weight, input, ...)` is deprecated, "
            "and should now be `embedding_bag(input, weight, ...)`.")
        weight, input = input, weight

    if (per_sample_weights is not None
            and input.size() != per_sample_weights.size()):
        raise ValueError(
            "embedding_bag: If per_sample_weights ({}) is not None, "
            "then it must have the same shape as the input ({})".format(
                per_sample_weights.shape, input.shape))

    # NOTE(review): if _not_impl_raise raises for every non-None value (as its
    # name suggests), the per_sample_weights/mode check further down can never
    # fire. Confirm before relying on it.
    _not_impl_raise(max_norm, "max_norm")
    _not_impl_raise(per_sample_weights, "per_sample_weights")

    if input.dim() == 2:
        if offsets is not None:
            type_str = "<unknown>"
            # TODO: Remove this once script supports type() calls
            if not torch.jit.is_scripting():
                type_str = str(type(offsets))
            raise ValueError("if input is 2D, then offsets has to be None"
                             ", as input is treated is a mini-batch of"
                             " fixed length sequences. However, found "
                             "offsets of type {}".format(type_str))
        nested_sizes = NestedTensor(input).nested_size()
        num_bags = len(nested_sizes)
        # Exclusive prefix sum of the entries' leading sizes: bag 0 starts at
        # 0 and bag i starts at sum(size_j[0] for j < i). One cumsum replaces
        # the per-element tensor writes of a Python loop. Entries are read by
        # index because nested_size() may not support slicing.
        offsets = torch.zeros(num_bags, dtype=torch.int64)
        if num_bags > 1:
            leading_sizes = torch.tensor(
                [nested_sizes[i][0] for i in range(num_bags - 1)],
                dtype=torch.int64)
            offsets[1:] = torch.cumsum(leading_sizes, dim=0)
        offsets = offsets.to(input.device)
    elif input.dim() == 1:
        raise ValueError("input has to be 2D NestedTensor,"
                         " but got NestedTensor of dimension {}".format(
                             input.dim()))

    # Integer codes expected by torch.embedding_bag: 0 = sum, 1 = mean, 2 = max.
    if mode == 'sum':
        mode_enum = 0
    elif mode == 'mean':
        mode_enum = 1
    elif mode == 'max':
        mode_enum = 2

        if scale_grad_by_freq:
            raise ValueError(
                "max mode does not support scaling the gradient by the frequency"
            )

        if sparse:
            raise ValueError("max mode does not support sparse weights")

    else:
        raise ValueError("mode has to be one of sum, mean or max")

    if per_sample_weights is not None and mode != 'sum':
        raise NotImplementedError(
            "embedding_bag: per_sample_weights was not None. "
            "per_sample_weights is only supported for mode='sum' "
            "(got mode='{}'). Please open a feature request on GitHub.".format(
                mode))

    ret, _, _, _ = torch.embedding_bag(weight, input, offsets,
                                       scale_grad_by_freq, mode_enum, sparse,
                                       per_sample_weights, include_last_offset)
    return ret