def scatter_mul(tensor: Tensor, index: Tensor) -> Tensor:
    """Multiply together the rows of ``tensor`` that fall into the same bag.

    The bags come from ``scatter_index_to_ptr(index=...)``. The product is
    evaluated in log space as ``exp(sum(log|x|))`` with ``torch.embedding_bag``
    in sum mode (mode 0). The sign is recovered separately from a parity count
    over each bag.

    Args:
        tensor: values to reduce; dimension 0 is the one split into bags, and
            any trailing dimensions are flattened for the reduction and restored
            afterwards.
        index: bag assignment handed to ``scatter_index_to_ptr``.

    Returns:
        One product row per bag, reshaped to ``(num_bags, *tensor.shape[1:])``.
    """
    indices, offsets = scatter_index_to_ptr(index=index, device=tensor.device)
    flat = tensor.view((tensor.size()[0], -1))

    def bag_sum(weight: Tensor) -> Tensor:
        # Per-bag sum of the rows of ``weight`` (embedding_bag mode 0).
        summed, _, _, _ = torch.embedding_bag(
            weight=weight,
            indices=indices,
            offsets=offsets,
            mode=0,
        )
        return summed

    # NOTE(review): an exact zero gives log|x| = -inf, so the forward product
    # is 0. The backward pass through log(0), however, likely produces NaN
    # gradients for that bag — confirm if gradients through zeros matter.
    log_magnitude = bag_sum(flat.abs().log())

    # Encode each sign as 1 - sign(x): negative -> 2, zero -> 1, positive -> 0.
    # Ignoring zeros, the per-bag sum modulo 4 is 0 for an even number of
    # negatives and 2 for an odd number, so 1 - (sum % 4) yields +1 / -1.
    sign_code = bag_sum(flat.sign().neg().add(1.))
    sign = 1. - sign_code % 4

    out_shape = (log_magnitude.size()[0], *tensor.size()[1:])
    return (sign.detach() * log_magnitude.exp()).view(out_shape)
def scatter_logsumexp(tensor: Tensor, index: Tensor) -> Tensor:
    """Compute ``log(sum(exp(x)))`` over the rows of ``tensor`` in each bag.

    The bags come from ``scatter_index_to_ptr(index=...)``; ``index`` also
    selects, for every row of ``tensor``, the per-bag maximum used to
    stabilise the exponentials.

    Each bag's maximum ``m`` is computed without gradient (embedding_bag mode
    2) and subtracted before exponentiating. The per-bag sum (mode 0) is then
    logged and ``m`` is added back. Because ``m`` cancels mathematically,
    detaching it does not change the gradient.

    Infinite maxima are shifted by 0 instead of by ``m``. Subtracting them
    would give ``inf - inf = nan``. With this shift, a bag whose entries are
    all ``-inf`` yields ``-inf``, and a bag containing ``+inf`` yields ``+inf``.

    Args:
        tensor: values to reduce; dimension 0 is split into bags, and trailing
            dimensions are flattened for the reduction and restored afterwards.
        index: bag assignment, one entry per row of ``tensor``.

    Returns:
        Tensor of shape ``(num_bags, *tensor.shape[1:])``. Empty bags give 0.
    """
    indices, offsets = scatter_index_to_ptr(index=index, device=tensor.device)
    tensor_view = tensor.view((tensor.size()[0], -1))

    with torch.no_grad():
        m, _, _, _ = torch.embedding_bag(
            weight=tensor_view,
            indices=indices,
            offsets=offsets,
            mode=2,
        )
        # Same guard torch.logsumexp uses: never subtract an infinite maximum.
        shift = torch.masked_fill(m, torch.isinf(m), 0.)

    z, _, _, _ = torch.embedding_bag(
        weight=(tensor_view - shift[index]).exp(),
        indices=indices,
        offsets=offsets,
        mode=0,
    )
    # z == 0 occurs for empty bags (embedding_bag returns zeros for them) and
    # for bags whose entries are all -inf. Replacing it with 1 keeps log()
    # finite, and adding the raw maximum m back gives 0 and -inf respectively.
    ret = torch.masked_fill(z, z == 0, 1.).log() + m
    return ret.view((ret.size()[0], *tensor.size()[1:]))
def scatter_min(tensor: Tensor, index: Tensor) -> Tensor:
    """Take the element-wise minimum over the rows of ``tensor`` in each bag.

    ``torch.embedding_bag`` reduces with sum, mean or max only, so the minimum
    is obtained as ``-max(-x)`` using max mode (mode 2).

    Args:
        tensor: values to reduce; dimension 0 is split into bags, and trailing
            dimensions are flattened for the reduction and restored afterwards.
        index: bag assignment handed to ``scatter_index_to_ptr``.

    Returns:
        Tensor of shape ``(num_bags, *tensor.shape[1:])``.
    """
    indices, offsets = scatter_index_to_ptr(index=index, device=tensor.device)
    negated = tensor.neg().view((tensor.size()[0], -1))

    bag_max, _, _, _ = torch.embedding_bag(
        weight=negated,
        indices=indices,
        offsets=offsets,
        mode=2,
    )

    out_shape = (bag_max.size()[0], *tensor.size()[1:])
    return bag_max.neg().view(out_shape)
def mlm_bag_catted_sequence(sequence: CattedSequence, index: CattedSequence, mode: int,
                            tokenizer: PreTrainedTokenizer, model: PreTrainedModel) -> CattedSequence:
    """Run the masked LM and pool its outputs into bags with ``torch.embedding_bag``.

    Args:
        sequence: catted sequence; must have the same ``token_sizes`` as
            ``index``. Its data's device is used for the bag indices.
        index: catted sequence passed to ``mlm_bag_catted_indices`` to build
            the bag ``indices`` / ``offsets`` and the output ``token_sizes``.
        mode: ``torch.embedding_bag`` reduction mode (0 = sum, 1 = mean,
            2 = max).
        tokenizer: forwarded to ``mlm_bag_catted_indices`` and
            ``mlm_padded_sequence``.
        model: forwarded to ``mlm_padded_sequence``.

    Returns:
        ``CattedSequence`` of pooled vectors with the ``token_sizes`` reported
        by ``mlm_bag_catted_indices``.

    Raises:
        AssertionError: if ``sequence`` and ``index`` have different
            ``token_sizes``.
    """
    assert torch.equal(sequence.token_sizes, index.token_sizes), f'{sequence.token_sizes} != {index.token_sizes}'

    device = sequence.data.device
    bag_sequence, indices, offsets, token_sizes = mlm_bag_catted_indices(
        index=index,
        tokenizer=tokenizer,
        device=device,
    )

    # NOTE(review): the model runs on the first value returned by
    # mlm_bag_catted_indices, not on the caller's ``sequence``. The caller's
    # ``sequence`` only feeds the token_sizes check and the device. Confirm
    # this is intended.
    hidden = mlm_padded_sequence(sequence=bag_sequence, tokenizer=tokenizer, model=model)

    # Presumably hidden is (batch, length, dim); merging the first two axes
    # gives embedding_bag one row per position — TODO confirm the shape.
    pooled, _, _, _ = torch.embedding_bag(
        hidden.flatten(start_dim=0, end_dim=1),
        indices=indices,
        offsets=offsets,
        mode=mode,
    )
    return CattedSequence(data=pooled, token_sizes=token_sizes)
def _nested_bag_offsets(input, include_last_offset):
    """Return ``torch.embedding_bag`` offsets making each entry of nested ``input`` one bag."""
    nested_sizes = NestedTensor(input).nested_size()
    # Indexed access mirrors how the nested size object was used originally.
    lengths = torch.tensor(
        [nested_sizes[i][0] for i in range(len(nested_sizes))],
        dtype=torch.int64,
    )
    # Exclusive prefix sum: bag i starts where bags 0..i-1 end.
    bag_offsets = lengths.cumsum(0) - lengths
    if include_last_offset:
        # In this mode embedding_bag expects one extra trailing offset equal to
        # the total number of indices (CSR style). Without it, the last
        # sequence would be dropped.
        bag_offsets = torch.cat([bag_offsets, lengths.sum().view(1)])
    return bag_offsets.to(input.device)


def _nn_functional_embedding_bag(input, weight, offsets=None, max_norm=None,
                                 norm_type=2, scale_grad_by_freq=False,
                                 mode='mean', sparse=False,
                                 per_sample_weights=None,
                                 include_last_offset=False):
    """NestedTensor-aware counterpart of ``torch.nn.functional.embedding_bag``.

    For 2-D nested ``input``, every nested entry becomes one bag. The offsets
    are derived from the nested sizes, so ``offsets`` must be None.

    Args:
        input: indices; a 2-D NestedTensor is expected (1-D input is rejected).
        weight: embedding matrix. If the two arguments arrive in the legacy
            ``(weight, input)`` order, they are swapped with a warning.
        offsets: must be None for 2-D input.
        max_norm: unsupported; passed to ``_not_impl_raise``.
        norm_type: accepted for signature compatibility; unused here.
        scale_grad_by_freq: forwarded to ``torch.embedding_bag``; rejected
            with ``mode='max'``.
        mode: one of ``'sum'``, ``'mean'``, ``'max'``.
        sparse: forwarded to ``torch.embedding_bag``; rejected with
            ``mode='max'``.
        per_sample_weights: shape-checked against ``input``, then passed to
            ``_not_impl_raise``.
        include_last_offset: forwarded to ``torch.embedding_bag``. The computed
            offsets then also carry the trailing end offset.

    Returns:
        The pooled embeddings returned by ``torch.embedding_bag``.

    Raises:
        ValueError: on a per_sample_weights shape mismatch, offsets given for
            2-D input, 1-D input, an unknown mode, or an option unsupported
            by ``mode='max'``.
        NotImplementedError: per_sample_weights with a mode other than 'sum'.
    """
    # Check for backward compatibility.
    # Used to be embedding_bag(weight, input, ...)
    # Now is embedding_bag(input, weight, ...)
    if weight.dtype == torch.long and input.is_floating_point():
        warnings.warn(
            "Argument order of nn.functional.embedding_bag was changed. "
            "Usage `embedding_bag(weight, input, ...)` is deprecated, "
            "and should now be `embedding_bag(input, weight, ...)`.")
        weight, input = input, weight

    if per_sample_weights is not None and input.size() != per_sample_weights.size():
        raise ValueError(
            "embedding_bag: If per_sample_weights ({}) is not None, "
            "then it must have the same shape as the input ({})".format(
                per_sample_weights.shape, input.shape))

    _not_impl_raise(max_norm, "max_norm")
    _not_impl_raise(per_sample_weights, "per_sample_weights")

    input_dim = input.dim()
    if input_dim == 2:
        if offsets is not None:
            type_str = "<unknown>"
            # TODO: Remove this once script supports type() calls
            if not torch.jit.is_scripting():
                type_str = str(type(offsets))
            raise ValueError("if input is 2D, then offsets has to be None"
                             ", as input is treated is a mini-batch of"
                             " fixed length sequences. However, found "
                             "offsets of type {}".format(type_str))
        offsets = _nested_bag_offsets(input, include_last_offset)
    elif input_dim == 1:
        raise ValueError("input has to be 2D NestedTensor,"
                         " but got NestedTensor of dimension {}".format(
                             input_dim))

    if mode == 'sum':
        mode_enum = 0
    elif mode == 'mean':
        mode_enum = 1
    elif mode == 'max':
        mode_enum = 2
        if scale_grad_by_freq:
            raise ValueError(
                "max mode does not support scaling the gradient by the frequency"
            )
        if sparse:
            raise ValueError("max mode does not support sparse weights")
    else:
        raise ValueError("mode has to be one of sum, mean or max")

    # NOTE(review): likely unreachable if _not_impl_raise above raises for any
    # non-None per_sample_weights; kept in case that helper ever relaxes.
    if per_sample_weights is not None and mode != 'sum':
        raise NotImplementedError(
            "embedding_bag: per_sample_weights was not None. "
            "per_sample_weights is only supported for mode='sum' "
            "(got mode='{}'). Please open a feature request on GitHub.".format(
                mode))

    ret, _, _, _ = torch.embedding_bag(weight, input, offsets,
                                       scale_grad_by_freq, mode_enum, sparse,
                                       per_sample_weights, include_last_offset)
    return ret