Example #1
def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int,
    batch_size: int,
    func: Callable,
    allow_none_pass_through: bool = False,
) -> Tuple:
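    # Strips the batch dim recorded at `vmap_level` from every output tensor and
    # moves it to the position requested by `out_dims`; when
    # `allow_none_pass_through` is True, None outputs are returned unchanged.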
    num_outputs = _num_outputs(batched_outputs)
    out_dims_as_tuple = _as_tuple(
        out_dims,
        num_outputs,
        lambda:
        f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
        f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.",
    )

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        out_dim = out_dims_as_tuple[0]
        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size,
                                       out_dim)  # type: ignore[return-value]
    if allow_none_pass_through:
        return tuple(
            (torch._remove_batch_dim(out, vmap_level, batch_size, out_dim
                                     ) if out is not None else None)
            for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
    else:
        return tuple(
            torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
            for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
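As a rough usage sketch, the observable effect of _unwrap_batched is that the vmap batch dimension ends up at the position requested by out_dims in every output. The snippet below assumes the legacy torch._vmap_internals module (where these helpers live) is still importable; the module path and the vmap entry point are assumptions rather than something shown in the examples here.

import torch
from torch._vmap_internals import vmap  # legacy vmap wrapper that ends in _unwrap_batched

def f(x):
    # Two outputs with different shapes; _unwrap_batched is applied to each one.
    return x.sum(dim=-1), x * 2

x = torch.randn(3, 5)
summed, scaled = vmap(f, in_dims=0, out_dims=0)(x)
# The batch dim of size 3 is removed from the BatchedTensor outputs and placed
# back at out_dims=0, so the shapes match an explicit Python loop over x.
assert summed.shape == (3,)
assert scaled.shape == (3, 5)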
Example #2
def _unwrap_batched(batched_outputs, vmap_level, batch_size):
    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, 0)  # type: ignore
    return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, 0)  # type: ignore
                 for out in batched_outputs)
Example #3
def _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, fn_name):
    num_outputs = _num_outputs(batched_outputs)
    out_dims_as_tuple = _as_tuple(
        out_dims, num_outputs,
        lambda: OUT_DIMS_AND_NUM_OUTPUTS_MISMATCH.format(
            fn=fn_name, out_dims=out_dims, num_outputs=num_outputs))

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        out_dim = out_dims_as_tuple[0]
        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore
    return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)  # type: ignore
                 for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
Example #4
def _unwrap_batched(batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
                    out_dims: out_dims_t, vmap_level: int, batch_size: int,
                    fn_name: str) -> Tuple:
    num_outputs = _num_outputs(batched_outputs)
    out_dims_as_tuple = _as_tuple(
        out_dims, num_outputs,
        lambda: f'vmap({fn_name}, ..., out_dims={out_dims}): `out_dims` must '
        f'have one dim per output (got {num_outputs} outputs) of {fn_name}.')

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        out_dim = out_dims_as_tuple[0]
        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size,
                                       out_dim)  # type: ignore
    return tuple(
        torch._remove_batch_dim(out, vmap_level, batch_size,
                                out_dim)  # type: ignore
        for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
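For orientation, the call flow around these helpers looks roughly like the sketch below. It is a simplified reconstruction under stated assumptions, not the real wrapper: it handles a single tensor argument and a single integer out_dims, and it assumes the legacy private bindings referenced in the NOTE comments (torch._add_batch_dim, torch._C._vmapmode_increment_nesting / _decrement_nesting) and the torch._vmap_internals module are available.

import torch
from torch import Tensor
from torch._vmap_internals import _unwrap_batched  # the variant shown in Example #1

def toy_vmap(func, in_dim: int = 0, out_dims: int = 0):
    # Minimal single-argument sketch of the wrapper that surrounds _unwrap_batched.
    def wrapped(arg: Tensor):
        batch_size = arg.size(in_dim)
        # Open a new vmap level; BatchedTensors created below are tagged with it.
        vmap_level = torch._C._vmapmode_increment_nesting()
        try:
            batched_input = torch._add_batch_dim(arg, in_dim, vmap_level)
            batched_outputs = func(batched_input)
            # _unwrap_batched removes that level's batch dim from every output and
            # moves it to the position requested by out_dims.
            return _unwrap_batched(batched_outputs, out_dims, vmap_level,
                                   batch_size, func)
        finally:
            torch._C._vmapmode_decrement_nesting()
    return wrapped

out = toy_vmap(torch.sin)(torch.randn(4, 3))
assert out.shape == (4, 3)

The real wrapper additionally validates in_dims and out_dims and batches every tensor argument, which is omitted here for brevity.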