Пример #1
0
def reversed_indices(sequence: PackedSequence) -> Tensor:
    """Return packed-data indices that mirror every sequence in time.

    Element ``(t, b)`` of the packed layout is mapped to the packed position
    of ``(L_b - 1 - t, b)``, where ``L_b`` is the length of sequence ``b``.

    Args:
        sequence: the packed sequence whose elements are indexed.

    Returns:
        A 1-D index tensor on ``sequence.data.device``.
    """
    batch_sizes = sequence.batch_sizes.to(device=sequence.data.device)
    # assumes accumulate_sizes yields exclusive prefix sums, i.e. the packed
    # offset at which every time step starts — TODO confirm
    step_offsets = accumulate_sizes(sizes=batch_sizes)

    # presumably (time step, batch slot, per-slot length) for each packed
    # element — confirm against batch_sizes_to_ptr
    time_ptr, batch_ptr, lengths = batch_sizes_to_ptr(batch_sizes=batch_sizes)

    # t -> L - 1 - t inside each sequence
    mirrored_ptr = lengths[batch_ptr] - 1 - time_ptr
    return step_offsets[mirrored_ptr] + batch_ptr
Пример #2
0
def last_indices(sequence: PackedSequence, unsort: bool = True) -> Tensor:
    """Return the packed-data position of the final element of each sequence.

    Args:
        sequence: the packed sequence to inspect.
        unsort: forwarded to ``head_indices``; presumably selects whether the
            batch order is restored to the caller's original order — confirm.

    Returns:
        A 1-D index tensor, one entry per sequence, on ``sequence.data.device``.
    """
    batch_sizes = sequence.batch_sizes.to(device=sequence.data.device)
    # assumes exclusive prefix sums: start offset of every time step — TODO confirm
    step_offsets = accumulate_sizes(sizes=batch_sizes)

    batch_ptr = head_indices(sequence=sequence, unsort=unsort)
    lengths = batch_sizes_to_token_sizes(batch_sizes=batch_sizes, batch_ptr=batch_ptr)
    # last valid time step of every selected sequence
    last_step = lengths - 1

    return step_offsets[last_step] + batch_ptr
Пример #3
0
def init_indices(sequence: PackedSequence, drop_last_n: int = 1) -> Tensor:
    """Return packed-data indices of the elements kept after dropping steps.

    ``batch_sizes`` is resized to ``len(batch_sizes) - drop_last_n`` entries
    via ``resize_sizes``. The result indexes the ORIGINAL packed data for
    every element described by the resized sizes. The exact per-sequence
    effect depends on ``resize_sizes`` (NOTE(review): presumably drops the
    last ``drop_last_n`` elements — confirm).

    Args:
        sequence: the packed sequence to index.
        drop_last_n: how many entries to remove from ``batch_sizes``.

    Returns:
        A 1-D index tensor on ``sequence.data.device``.
    """
    num_kept_steps = sequence.batch_sizes.size()[0] - drop_last_n

    batch_sizes = sequence.batch_sizes.to(device=sequence.data.device)
    # offsets come from the full sizes so they address the original packed data
    step_offsets = accumulate_sizes(sizes=batch_sizes)

    kept_sizes = resize_sizes(sizes=batch_sizes, n=num_kept_steps)
    time_ptr, batch_ptr, _ = batch_sizes_to_ptr(batch_sizes=kept_sizes)

    return step_offsets[time_ptr] + batch_ptr
Пример #4
0
def rolled_indices(sequence: PackedSequence, shifts: int) -> Tensor:
    """Return packed-data indices that roll every sequence by ``shifts``.

    Element ``(t, b)`` reads from time step ``(t - shifts) mod L_b`` of the
    same sequence, where ``L_b`` is that sequence's length. This matches the
    ``torch.roll`` convention, so positive shifts move data forward.

    Args:
        sequence: the packed sequence to roll.
        shifts: number of steps to roll by; may be negative.

    Returns:
        A 1-D index tensor on ``sequence.data.device``.
    """
    batch_sizes = sequence.batch_sizes.to(device=sequence.data.device)
    # assumes exclusive prefix sums: start offset of every time step — TODO confirm
    step_offsets = accumulate_sizes(sizes=batch_sizes)
    time_ptr, batch_ptr, sorted_lengths = batch_sizes_to_ptr(batch_sizes=batch_sizes)

    # length of the sequence each packed element belongs to
    per_element_length = sorted_lengths[batch_ptr]
    # tensor `%` is a Python-style (non-negative) modulo, so the wrap-around
    # is correct for any sign of `shifts`
    source_ptr = (time_ptr + per_element_length - shifts) % per_element_length

    return step_offsets[source_ptr] + batch_ptr
Пример #5
0
def scatter_index_to_ptr(index: Tensor,
                         dtype: torch.dtype = torch.long,
                         device: Device = None) -> Tuple[Tensor, Tensor]:
    """Group the positions of ``index`` by value.

    Args:
        index: 1-D tensor of non-negative group ids.
        dtype: dtype ``index`` is converted to. ``scatter_add_`` needs an
            int64 index, so this presumably should stay ``torch.long``.
        device: target device; defaults to ``index.device``.

    Returns:
        ``(sorted_indices, offsets)``:
        - ``sorted_indices`` is ``argsort(index)``; the sort is not requested
          as stable, so the order within a group is unspecified.
        - ``offsets`` is ``accumulate_sizes`` of the per-value counts
          ``counts[v] = #(index == v)`` for ``v`` in ``0 .. index.max()``.
    """
    target_device = index.device if device is None else device
    index = index.to(dtype=dtype, device=target_device)

    sorted_indices = index.argsort(dim=0, descending=False)

    # histogram of the group ids (raises on an empty index via max())
    num_groups = index.max().item() + 1
    counts = torch.zeros(num_groups, dtype=dtype, device=target_device)
    counts.scatter_add_(dim=0, index=index, src=torch.ones_like(index))

    return sorted_indices, accumulate_sizes(sizes=counts)
Пример #6
0
def cat_packed_indices(batch_sizes: Tensor, unsorted_indices: Optional[Tensor], device: Device = None):
    """Compute packed-data positions listed in the order of ``token_sizes_to_ptr``.

    Presumably used to gather packed data into a concatenated layout, as the
    name suggests — confirm against callers.

    Args:
        batch_sizes: ``batch_sizes`` of the packed sequence.
        unsorted_indices: optional tensor forwarded to ``token_sizes_to_ptr``
            as ``token_ptr``; may be ``None``.
        device: target device. Defaults to ``unsorted_indices.device``, or to
            ``batch_sizes.device`` when ``unsorted_indices`` is ``None``.

    Returns:
        ``(indices, token_sizes)``:
        - ``indices[i] = acc_batch_sizes[token_ptr[i]] + batch_ptr[i]`` is a
          position in the packed data.
        - ``token_sizes`` is the third output of ``token_sizes_to_ptr``.
    """
    if device is None:
        # unsorted_indices is Optional, so it cannot always supply the device
        if unsorted_indices is None:
            device = batch_sizes.device
        else:
            device = unsorted_indices.device

    # keep every tensor fed to the helpers on the same device
    batch_sizes = batch_sizes.to(device=device)
    if unsorted_indices is not None:
        unsorted_indices = unsorted_indices.to(device=device)

    batch_ptr, token_ptr, token_sizes = token_sizes_to_ptr(
        token_sizes=batch_sizes,
        token_ptr=unsorted_indices,
    )
    # assumes exclusive prefix sums: start offset of every time step — TODO confirm
    acc_batch_sizes = accumulate_sizes(sizes=batch_sizes)

    indices = acc_batch_sizes[token_ptr] + batch_ptr
    return indices, token_sizes
Пример #7
0
def pack_catted_indices(token_sizes: Tensor, device: Device = None):
    """Compute catted-data positions listed in packed order.

    Presumably used to gather concatenated data into a packed layout, as the
    name suggests — confirm against callers.

    Args:
        token_sizes: number of tokens in each sequence of the catted data.
        device: target device; defaults to ``token_sizes.device``.

    Returns:
        ``(indices, batch_sizes, sorted_indices, unsorted_indices)``:
        - ``indices[i] = acc_token_sizes[batch_ptr[i]] + token_ptr[i]`` is a
          position in the catted data.
        - The remaining three come from ``sizes_to_sorting_indices`` and
          ``token_sizes_to_ptr``.
    """
    if device is None:
        device = token_sizes.device

    # Move token_sizes up front. Otherwise the offsets below stay on the
    # original device while the pointers used to index them are produced
    # on `device`.
    token_sizes = token_sizes.to(device=device)

    sorted_token_sizes, sorted_indices, unsorted_indices = sizes_to_sorting_indices(
        sizes=token_sizes,
        device=device,
    )
    token_ptr, batch_ptr, batch_sizes = token_sizes_to_ptr(
        token_sizes=sorted_token_sizes,
        batch_ptr=sorted_indices,
    )
    # assumes exclusive prefix sums: start offset of every sequence — TODO confirm
    acc_token_sizes = accumulate_sizes(sizes=token_sizes)
    indices = acc_token_sizes[batch_ptr] + token_ptr

    return indices, batch_sizes, sorted_indices, unsorted_indices
Пример #8
0
def tree_reduce_catted_indices(token_sizes: Tensor) -> TreeReduceIndices:
    """Build tree-reduction indices for sequences stored back to back (catted).

    Args:
        token_sizes: number of tokens in each sequence of the catted data.

    Returns:
        ``TreeReduceIndices`` with ``src=(src1, src2)``:
        - ``src1`` holds the catted-data position of every token.
        - ``src2`` holds the slot each token presumably occupies in the
          reduction buffer laid out by ``tree_reduce_indices``.
        ``xs``/``ys``/``zs``/``dst``/``num_steps`` are forwarded unchanged.
    """
    # assumes: first output = sequence index, second = position inside the
    # sequence, one entry per token — TODO confirm against batch_sizes_to_ptr
    seq_ptr, pos_ptr, _ = batch_sizes_to_ptr(batch_sizes=token_sizes)
    seq_offsets = accumulate_sizes(sizes=token_sizes)

    (xs, ys, zs,
     slot_ptr, slot_offsets,
     dst, num_steps) = tree_reduce_indices(token_sizes1=token_sizes)

    catted_positions = seq_offsets[seq_ptr] + pos_ptr
    buffer_slots = slot_offsets[seq_ptr] + slot_ptr

    return TreeReduceIndices(
        xs=xs, ys=ys, zs=zs,
        src=(catted_positions, buffer_slots),
        dst=dst,
        num_steps=num_steps,
    )
Пример #9
0
def tree_reduce_packed_indices(batch_sizes: Tensor) -> TreeReduceIndices:
    """Build tree-reduction indices for a packed sequence.

    Unlike the catted variant, the data is not reordered first. Instead, the
    buffer slots are permuted into packed order.

    Args:
        batch_sizes: ``batch_sizes`` of the packed sequence.

    Returns:
        ``TreeReduceIndices`` with ``src=(..., src2)``:
        - ``src2[i]`` is presumably the buffer slot of the element at packed
          position ``i``. This assumes ``invert_permutation`` returns the
          inverse permutation.
        - The ``...`` placeholder presumably tells the consumer to index the
          packed data as-is (``data[...]`` selects everything) — confirm.
    """
    seq_ptr, step_ptr, token_sizes = token_sizes_to_ptr(token_sizes=batch_sizes)
    # assumes exclusive prefix sums: start offset of every time step — TODO confirm
    step_offsets = accumulate_sizes(sizes=batch_sizes)

    (xs, ys, zs,
     slot_ptr, slot_offsets,
     dst, num_steps) = tree_reduce_indices(token_sizes1=token_sizes)

    packed_positions = step_offsets[step_ptr] + seq_ptr
    buffer_slots = slot_offsets[seq_ptr] + slot_ptr
    # reorder so entry i refers to the element stored at packed position i
    buffer_slots = buffer_slots[invert_permutation(packed_positions)]

    return TreeReduceIndices(
        xs=xs, ys=ys, zs=zs,
        src=(..., buffer_slots),
        dst=dst,
        num_steps=num_steps,
    )
Пример #10
0
def tree_reduce_indices(token_sizes1: Tensor):
    """Build the pairing schedule of a binary tree reduction for every sequence.

    Slot layout:
    - A sequence with ``n`` tokens gets ``2 * n - 1`` slots in a flat buffer.
      That is enough for ``n`` tokens plus the ``n - 1`` pairwise results of
      a binary reduction.
    - The slot count is split greedily into chunks of at most 1, 2, 4, ...
      slots (see ``opt_sizes``). The one-slot chunk is dropped.

    Schedule:
    - The loop walks the remaining chunks from the largest down.
    - In each chunk, slots ``(x, x + 1)`` are paired, and each pair's result
      goes to a slot ``z`` just past the chunk.
    - The next (smaller) chunk starts exactly at those results.

    Args:
        token_sizes1: 1-D integer tensor with the number of tokens of each
            sequence. Assumed non-empty, with every entry >= 1 — TODO confirm.
            An empty tensor makes ``acc_token_sizes2[-1]`` fail.

    Returns:
        A 7-tuple ``(xs, ys, zs, token_ptr2, acc_token_sizes2, dst, num_steps)``:

        - ``xs``, ``ys``, ``zs``: lists of flat-buffer index tensors, one entry
          per step (largest chunk first). Step ``k`` pairs ``xs[k]`` with
          ``ys[k] == xs[k] + 1``, and the result slot is ``zs[k]``.
        - ``token_ptr2``: position (within its sequence) of every slot that is
          never a result slot.
        - ``acc_token_sizes2``: exclusive prefix sums of the slot counts, i.e.
          the start of each sequence in the flat buffer.
        - ``dst``: flat index of the last slot of each sequence.
        - ``num_steps``: total slot count over all sequences, as a Python
          scalar. Despite the name, this is a buffer size, not ``len(xs)``.
    """
    # 2n - 1 slots per sequence: n tokens + (n - 1) pairwise results
    token_sizes2 = token_sizes1 * 2 - 1
    # assumes the second output is each slot's position inside its sequence
    # (0 .. token_sizes2[b] - 1) — TODO confirm against batch_sizes_to_ptr
    _, token_ptr2, _ = batch_sizes_to_ptr(batch_sizes=token_sizes2)

    acc_token_sizes2 = token_sizes2.cumsum(dim=0)
    # total number of slots over all sequences (Python scalar via .item())
    num_steps = acc_token_sizes2[-1].item()
    # inclusive cumsum - 1 = flat index of the last slot of every sequence
    dst = acc_token_sizes2 - 1
    # pad [1, -1] prepends a zero and crops the last element, turning the
    # inclusive prefix sums into exclusive ones (sequence start offsets)
    acc_token_sizes2 = F.pad(acc_token_sizes2, [1, -1])

    # start of the current chunk inside each sequence; advanced every step
    offsets = acc_token_sizes2.clone()
    # stays True for slots never written as a result slot `z`
    mask = torch.ones_like(token_ptr2, dtype=torch.bool)

    # powers of two 1, 2, 4, ..., 2**(bits - 2) for the input's integer dtype
    sizes = 2**torch.arange(torch.iinfo(token_sizes1.dtype).bits - 1,
                            device=token_sizes1.device)
    # assumes accumulate_sizes returns exclusive prefix sums (0, 1, 3, 7, ...)
    acc_sizes = accumulate_sizes(sizes=sizes)
    # opt_sizes[b, k]: how many of sequence b's slots fall into the k-th chunk
    # when they are split greedily into chunks of 1, 2, 4, ... slots;
    # under the assumption above, every column k >= 1 holds an even value
    opt_sizes = (token_sizes2[:, None] -
                 acc_sizes[None, :]).clamp_min(0).min(sizes)
    # non-zero columns form a prefix: keep them, but drop column 0 (the
    # one-slot chunk, which is never paired)
    opt_sizes = opt_sizes[:, 1:opt_sizes.any(dim=0).long().sum()]

    xs, ys, zs = [], [], []
    # walk the chunks from the largest to the smallest
    for index in range(opt_sizes.size()[1] - 1, -1, -1):
        opt_size = opt_sizes[:, index]
        # one (pair index i, sequence b) entry per pair in this chunk;
        # assumes token_sizes_to_ptr enumerates i < opt_size[b] // 2 for every b
        token_ptr, batch_ptr, _ = token_sizes_to_ptr(token_sizes=torch.div(
            opt_size, 2, rounding_mode='trunc'), )
        ptr = offsets[batch_ptr] + token_ptr

        # left operand of pair i: chunk start + 2 * i (right operand is x + 1)
        x = ptr + token_ptr
        # result of pair i: chunk start + chunk size + i, i.e. just past the chunk
        z = ptr + opt_size[batch_ptr]
        xs.append(x)
        ys.append(x + 1)
        zs.append(z)

        mask[z] = False
        # the next chunk begins where this step's results were written
        offsets += opt_size

    return xs, ys, zs, token_ptr2[mask], acc_token_sizes2, dst, num_steps