def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int,
    batch_size: int,
    func: Callable,
    allow_none_pass_through: bool = False,
) -> Tuple:
    """Remove the vmap batch dimension from the output(s) of ``func``.

    Every output is handed to ``torch._remove_batch_dim`` at ``vmap_level``,
    paired with its matching entry of ``out_dims``.

    Args:
        batched_outputs: A single Tensor, or a tuple of Tensors, produced by
            ``func`` under vmap.
        out_dims: The out_dims spec; ``_as_tuple`` turns it into one entry per
            output (presumably broadcasting a single int — confirm against
            ``_as_tuple``).
        vmap_level: The vmap level whose batch dimension is removed.
        batch_size: The size of the batch dimension.
        func: The vmapped function. It is used here only to name it in the
            mismatch error message.
        allow_none_pass_through: If True, ``None`` entries in a tuple of
            outputs are returned as ``None`` and are not passed to
            ``torch._remove_batch_dim``.

    Returns:
        The unbatched Tensor when ``batched_outputs`` is a Tensor. Otherwise a
        tuple with one entry per (output, out_dim) pair.
    """
    num_outputs = _num_outputs(batched_outputs)

    def _mismatch_message():
        # Passed as a callable so the message is only built on demand.
        return (
            f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
            f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}."
        )

    out_dims_as_tuple = _as_tuple(out_dims, num_outputs, _mismatch_message)

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        return torch._remove_batch_dim(  # type: ignore[return-value]
            batched_outputs, vmap_level, batch_size, out_dims_as_tuple[0]
        )

    unbatched = []
    for out, out_dim in zip(batched_outputs, out_dims_as_tuple):
        if allow_none_pass_through and out is None:
            # Let None outputs through untouched when the caller opted in.
            unbatched.append(None)
        else:
            unbatched.append(
                torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
            )
    return tuple(unbatched)
def _unwrap_batched(batched_outputs, vmap_level, batch_size):
    """Remove the vmap batch dimension from one Tensor or a tuple of Tensors.

    The batch dimension for ``vmap_level`` is always moved to dim 0 of each
    output.

    Returns:
        A single Tensor if ``batched_outputs`` is a Tensor, otherwise a tuple
        with one unbatched entry per element of ``batched_outputs``.
    """
    # Outputs always receive the batch dimension at the front.
    destination_dim = 0

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if not isinstance(batched_outputs, Tensor):
        return tuple(
            torch._remove_batch_dim(item, vmap_level, batch_size, destination_dim)  # type: ignore
            for item in batched_outputs
        )
    return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, destination_dim)  # type: ignore
def _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, fn_name):
    """Remove the vmap batch dimension from the output(s) of ``fn_name``.

    Each output is passed to ``torch._remove_batch_dim`` at ``vmap_level``,
    using the matching entry of ``out_dims`` as its destination dimension.

    Args:
        batched_outputs: A single Tensor or a tuple of Tensors.
        out_dims: The out_dims spec; ``_as_tuple`` expands it to one entry per
            output (expansion rules live in ``_as_tuple`` — confirm there).
        vmap_level: The vmap level whose batch dimension is removed.
        batch_size: The size of the batch dimension.
        fn_name: The function name to put in the mismatch error message.

    Returns:
        A Tensor when ``batched_outputs`` is a Tensor, otherwise a tuple.
    """
    num_outputs = _num_outputs(batched_outputs)

    def _format_mismatch():
        return OUT_DIMS_AND_NUM_OUTPUTS_MISMATCH.format(
            fn=fn_name, out_dims=out_dims, num_outputs=num_outputs)

    per_output_dims = _as_tuple(out_dims, num_outputs, _format_mismatch)

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if not isinstance(batched_outputs, Tensor):
        unbatched = [
            torch._remove_batch_dim(tensor, vmap_level, batch_size, dim)  # type: ignore
            for tensor, dim in zip(batched_outputs, per_output_dims)
        ]
        return tuple(unbatched)

    return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, per_output_dims[0])  # type: ignore
def _unwrap_batched(batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
                    out_dims: out_dims_t,
                    vmap_level: int,
                    batch_size: int,
                    fn_name: str) -> Tuple:
    """Remove the vmap batch dimension from the output(s) of ``fn_name``.

    Args:
        batched_outputs: A single Tensor or a tuple of Tensors.
        out_dims: The out_dims spec; ``_as_tuple`` expands it to one entry per
            output (expansion rules live in ``_as_tuple`` — confirm there).
        vmap_level: The vmap level whose batch dimension is removed.
        batch_size: The size of the batch dimension.
        fn_name: The function name to put in the mismatch error message.

    Returns:
        The unbatched Tensor when ``batched_outputs`` is a Tensor, otherwise a
        tuple with one entry per (output, out_dim) pair.
    """
    num_outputs = _num_outputs(batched_outputs)

    def _build_error():
        return (f'vmap({fn_name}, ..., out_dims={out_dims}): `out_dims` must '
                f'have one dim per output (got {num_outputs} outputs) of {fn_name}.')

    dims = _as_tuple(out_dims, num_outputs, _build_error)

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    is_single_tensor = isinstance(batched_outputs, Tensor)
    if is_single_tensor:
        first_dim = dims[0]
        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, first_dim)  # type: ignore

    collected = []
    for output, dim in zip(batched_outputs, dims):
        collected.append(
            torch._remove_batch_dim(output, vmap_level, batch_size, dim))  # type: ignore
    return tuple(collected)