from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils.rnn import PackedSequence, invert_permutation

# `accumulate_sizes`, `batch_sizes_to_ptr`, `token_sizes_to_ptr`, `resize_sizes`,
# `head_indices`, `batch_sizes_to_token_sizes`, `sizes_to_sorting_indices`, and
# the `TreeReduceIndices` container are helpers defined elsewhere in this
# package. `Device` is the package's device-argument alias; a plain stand-in:
Device = Optional[torch.device]


def reversed_indices(sequence: PackedSequence) -> Tensor:
    """Indices that reverse every sequence of a PackedSequence token-wise."""
    device = sequence.data.device

    batch_sizes = sequence.batch_sizes.to(device=device)
    acc_batch_sizes = accumulate_sizes(sizes=batch_sizes)

    token_ptr, batch_ptr, sorted_lengths = batch_sizes_to_ptr(batch_sizes=batch_sizes)
    # mirror each position within its own sequence: t -> (length - 1) - t
    token_ptr = (sorted_lengths - 1)[batch_ptr] - token_ptr

    return acc_batch_sizes[token_ptr] + batch_ptr
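

# A minimal usage sketch (an assumption about intended use, not part of the
# package; the `_demo_*` name is illustrative): reversing every sequence in a
# PackedSequence is a single gather on `.data` with these indices.
def _demo_reversed_indices() -> None:
    from torch.nn.utils.rnn import pack_sequence

    packed = pack_sequence([torch.arange(3), torch.arange(5)], enforce_sorted=False)
    reversed_packed = packed._replace(data=packed.data[reversed_indices(packed)])
    print(reversed_packed.data)  # each sequence's tokens now run back to front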


def last_indices(sequence: PackedSequence, unsort: bool = True) -> Tensor:
    """Indices of the last token of every sequence in a PackedSequence."""
    device = sequence.data.device

    batch_sizes = sequence.batch_sizes.to(device=device)
    acc_batch_sizes = accumulate_sizes(sizes=batch_sizes)

    batch_ptr = head_indices(sequence=sequence, unsort=unsort)
    # each sequence's last token sits at position (length - 1)
    token_ptr = batch_sizes_to_token_sizes(batch_sizes=batch_sizes, batch_ptr=batch_ptr) - 1

    return acc_batch_sizes[token_ptr] + batch_ptr
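

# Usage sketch (assumption): gather the final step of every sequence, e.g. the
# last RNN output per batch element; with unsort=True the result follows the
# original batch order.
def _demo_last_indices() -> None:
    from torch.nn.utils.rnn import pack_sequence

    packed = pack_sequence([torch.arange(4), torch.arange(2)], enforce_sorted=False)
    last_tokens = packed.data[last_indices(packed, unsort=True)]
    print(last_tokens)  # expected: tensor([3, 1]), the last token of each sequence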


def init_indices(sequence: PackedSequence, drop_last_n: int = 1) -> Tensor:
    """Indices of the `init` of each sequence, dropping its last `drop_last_n` tokens."""
    device = sequence.data.device
    n = sequence.batch_sizes.size()[0] - drop_last_n

    batch_sizes = sequence.batch_sizes.to(device=device)
    acc_batch_sizes = accumulate_sizes(sizes=batch_sizes)

    batch_sizes = resize_sizes(sizes=batch_sizes, n=n)
    token_ptr, batch_ptr, _ = batch_sizes_to_ptr(batch_sizes=batch_sizes)

    return acc_batch_sizes[token_ptr] + batch_ptr


def rolled_indices(sequence: PackedSequence, shifts: int) -> Tensor:
    """Indices that roll every sequence by `shifts` tokens, wrapping around."""
    device = sequence.data.device

    batch_sizes = sequence.batch_sizes.to(device=device)
    acc_batch_sizes = accumulate_sizes(sizes=batch_sizes)

    token_ptr, batch_ptr, sorted_lengths = batch_sizes_to_ptr(batch_sizes=batch_sizes)
    lengths = sorted_lengths[batch_ptr]
    # shift each position within its own sequence, modulo that sequence's length
    token_ptr = (token_ptr - shifts + lengths) % lengths

    return acc_batch_sizes[token_ptr] + batch_ptr
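

# Usage sketch (assumption): the packed analogue of `torch.roll`, applied to
# every sequence independently so tokens wrap within their own sequence.
def _demo_rolled_indices() -> None:
    from torch.nn.utils.rnn import pack_sequence

    packed = pack_sequence([torch.arange(3)])  # a single sequence [0, 1, 2]
    rolled = packed.data[rolled_indices(packed, shifts=1)]
    print(rolled)  # expected: tensor([2, 0, 1])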


def scatter_index_to_ptr(index: Tensor, dtype: torch.dtype = torch.long,
                         device: Device = None) -> Tuple[Tensor, Tensor]:
    """Sorting order and accumulated bucket sizes for a scatter index."""
    if device is None:
        device = index.device

    index = index.to(dtype=dtype, device=device)
    sorted_indices = torch.argsort(index, dim=0, descending=False)

    # count how many positions fall into each bucket, then accumulate the counts
    token_sizes = torch.zeros(index.max().item() + 1, dtype=dtype, device=device)
    token_sizes = token_sizes.scatter_add_(dim=0, index=index, src=torch.ones_like(index))

    return sorted_indices, accumulate_sizes(sizes=token_sizes)
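

# Usage sketch (assumption): the pair groups positions bucket by bucket —
# `sorted_indices` lists them in bucket order and the accumulated sizes mark
# where each bucket starts in that order.
def _demo_scatter_index_to_ptr() -> None:
    index = torch.tensor([2, 0, 1, 0, 2])
    sorted_indices, acc_sizes = scatter_index_to_ptr(index=index)
    print(sorted_indices)  # positions of bucket 0, then bucket 1, then bucket 2
    print(acc_sizes)       # start offset of each bucket, e.g. tensor([0, 2, 3])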


def cat_packed_indices(batch_sizes: Tensor, unsorted_indices: Optional[Tensor],
                       device: Device = None) -> Tuple[Tensor, Tensor]:
    """Indices that reorder PackedSequence data into catted (sequence-major) order."""
    if device is None:
        device = unsorted_indices.device  # unsorted_indices is assumed given here

    batch_sizes = batch_sizes.to(device=device)
    batch_ptr, token_ptr, token_sizes = token_sizes_to_ptr(
        token_sizes=batch_sizes,
        token_ptr=unsorted_indices,
    )

    acc_batch_sizes = accumulate_sizes(sizes=batch_sizes)
    indices = acc_batch_sizes[token_ptr] + batch_ptr

    return indices, token_sizes
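

# Usage sketch (assumption): turn the step-major `.data` of a PackedSequence
# back into sequence-major order with one gather; `token_sizes` reports the
# recovered per-sequence lengths.
def _demo_cat_packed_indices() -> None:
    from torch.nn.utils.rnn import pack_sequence

    packed = pack_sequence([torch.arange(3), torch.arange(2)], enforce_sorted=False)
    indices, token_sizes = cat_packed_indices(
        batch_sizes=packed.batch_sizes,
        unsorted_indices=packed.unsorted_indices,
    )
    print(packed.data[indices])  # tokens of sequence 0, then of sequence 1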


def pack_catted_indices(token_sizes: Tensor,
                        device: Device = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Indices and metadata that reorder catted data into PackedSequence order."""
    if device is None:
        device = token_sizes.device

    sorted_token_sizes, sorted_indices, unsorted_indices = sizes_to_sorting_indices(
        sizes=token_sizes, device=device,
    )
    token_ptr, batch_ptr, batch_sizes = token_sizes_to_ptr(
        token_sizes=sorted_token_sizes,
        batch_ptr=sorted_indices,
    )

    acc_token_sizes = accumulate_sizes(sizes=token_sizes)
    indices = acc_token_sizes[batch_ptr] + token_ptr

    return indices, batch_sizes, sorted_indices, unsorted_indices
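

# Usage sketch (assumption): build a PackedSequence from catted data; note that
# the PackedSequence constructor expects `batch_sizes` on the CPU.
def _demo_pack_catted_indices() -> None:
    data = torch.arange(5)               # sequences [0, 1, 2] and [3, 4], concatenated
    token_sizes = torch.tensor([3, 2])

    indices, batch_sizes, sorted_indices, unsorted_indices = pack_catted_indices(token_sizes=token_sizes)
    packed = PackedSequence(
        data=data[indices],
        batch_sizes=batch_sizes.cpu(),
        sorted_indices=sorted_indices,
        unsorted_indices=unsorted_indices,
    )
    print(packed.data)  # expected step-major order: tensor([0, 3, 1, 4, 2])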


def tree_reduce_catted_indices(token_sizes: Tensor) -> TreeReduceIndices:
    # role swap: passing token sizes as batch sizes enumerates the
    # (sequence, position) pairs of the catted layout
    batch_ptr1, token_ptr1, _ = batch_sizes_to_ptr(batch_sizes=token_sizes)
    acc_token_sizes1 = accumulate_sizes(sizes=token_sizes)

    xs, ys, zs, token_ptr2, acc_token_sizes2, dst, num_steps = tree_reduce_indices(token_sizes1=token_sizes)
    src1 = acc_token_sizes1[batch_ptr1] + token_ptr1  # token positions in the catted data
    src2 = acc_token_sizes2[batch_ptr1] + token_ptr2  # leaf slots in the tree buffer

    return TreeReduceIndices(
        xs=xs, ys=ys, zs=zs,
        src=(src1, src2), dst=dst,
        num_steps=num_steps,
    )


def tree_reduce_packed_indices(batch_sizes: Tensor) -> TreeReduceIndices:
    # role swap: passing batch sizes as token sizes enumerates the packed
    # layout; `token_sizes1` are then the sorted sequence lengths
    batch_ptr1, token_ptr1, token_sizes1 = token_sizes_to_ptr(token_sizes=batch_sizes)
    acc_batch_sizes1 = accumulate_sizes(sizes=batch_sizes)

    xs, ys, zs, token_ptr2, acc_token_sizes2, dst, num_steps = tree_reduce_indices(token_sizes1=token_sizes1)
    src1 = acc_batch_sizes1[token_ptr1] + batch_ptr1  # token positions in the packed data
    src2 = acc_token_sizes2[batch_ptr1] + token_ptr2  # leaf slots in the tree buffer
    # reorder the leaf slots to follow the packed data order; `src1` then
    # reduces to the identity, written as `...` below
    src2 = src2[invert_permutation(src1)]

    return TreeReduceIndices(
        xs=xs, ys=ys, zs=zs,
        src=(..., src2), dst=dst,
        num_steps=num_steps,
    )


def tree_reduce_indices(token_sizes1: Tensor):
    # a sequence of n tokens needs 2n - 1 tree slots: n leaves, n - 1 inner nodes
    token_sizes2 = token_sizes1 * 2 - 1

    _, token_ptr2, _ = batch_sizes_to_ptr(batch_sizes=token_sizes2)
    acc_token_sizes2 = token_sizes2.cumsum(dim=0)
    num_steps = acc_token_sizes2[-1].item()  # total number of tree-buffer slots
    dst = acc_token_sizes2 - 1  # the root of each tree is the last slot of its block
    acc_token_sizes2 = F.pad(acc_token_sizes2, [1, -1])  # exclusive prefix sums
    offsets = acc_token_sizes2.clone()
    mask = torch.ones_like(token_ptr2, dtype=torch.bool)

    # split each tree into chunks bounded by successive powers of two,
    # dropping the size-1 column and any trailing all-zero columns
    sizes = 2 ** torch.arange(torch.iinfo(token_sizes1.dtype).bits - 1, device=token_sizes1.device)
    acc_sizes = accumulate_sizes(sizes=sizes)
    opt_sizes = (token_sizes2[:, None] - acc_sizes[None, :]).clamp_min(0).min(sizes)
    opt_sizes = opt_sizes[:, 1:opt_sizes.any(dim=0).long().sum()]

    xs, ys, zs = [], [], []
    for index in range(opt_sizes.size()[1] - 1, -1, -1):  # largest chunks first
        opt_size = opt_sizes[:, index]
        token_ptr, batch_ptr, _ = token_sizes_to_ptr(
            token_sizes=torch.div(opt_size, 2, rounding_mode='trunc'),
        )

        # pair up slots 2t and 2t + 1 of the chunk, writing results just past it,
        # where they feed the next (smaller) chunk's round
        ptr = offsets[batch_ptr] + token_ptr
        x = ptr + token_ptr
        z = ptr + opt_size[batch_ptr]
        xs.append(x)
        ys.append(x + 1)
        zs.append(z)

        mask[z] = False  # slots written here hold reduced values, not leaves
        offsets += opt_size

    # token_ptr2[mask] addresses the leaf slots left for the original tokens
    return xs, ys, zs, token_ptr2[mask], acc_token_sizes2, dst, num_steps
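

# How these index sets are meant to be consumed — a sketch under the assumption
# that `TreeReduceIndices` carries exactly the fields built above (`xs`, `ys`,
# `zs`, `src`, `dst`, `num_steps`): scatter the input into the leaf slots of
# per-sequence tree buffers, fold pairs step by step, then read off each root.
# `op` must be associative, e.g. `torch.add` or `torch.maximum`.
def tree_reduce_sketch(data: Tensor, indices: TreeReduceIndices, op) -> Tensor:
    src1, src2 = indices.src
    buffer = data.new_empty((indices.num_steps, *data.size()[1:]))
    buffer[src2] = data[src1]  # place tokens into leaf slots (`...` selects all)

    # each step consumes one power-of-two chunk per sequence; results written
    # at `z` feed the reads of later, smaller chunks
    for x, y, z in zip(indices.xs, indices.ys, indices.zs):
        buffer[z] = op(buffer[x], buffer[y])

    return buffer[indices.dst]  # one reduced value per sequence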