def wrapped_with_chunks(*args, **kwargs):
    _check_out_dims_is_int_or_int_pytree(out_dims, func)
    _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)
    # Chunk flat arguments
    chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)

    # Apply vmap on chunks
    chunks_output = []
    rs = torch.get_rng_state() if randomness == "same" else None
    for flat_args in chunks_flat_args:
        batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
        if rs is not None:
            torch.set_rng_state(rs)
        chunks_output.append(
            _flat_vmap(func, batch_size, flat_in_dims, flat_args,
                       args_spec, out_dims, randomness, **kwargs))
    flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)
    # Removing temporary variables helps to reduce memory usage on devices like CUDA
    del chunks_output

    # concat chunks on out_dim
    flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
    assert len(flat_out_dims) == len(flat_output_chunks)
    flat_output = []
    for out_dim in flat_out_dims:
        flat_output.append(torch.cat(flat_output_chunks[0], dim=out_dim))
        # release source data
        del flat_output_chunks[0]
    del flat_output_chunks

    # finally unflatten the output
    return tree_unflatten(flat_output, arg_spec)
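# A minimal usage sketch (assumption: this wrapper is the chunked execution
# path behind the public `torch.func.vmap(..., chunk_size=...)`, available in
# recent PyTorch releases). Splitting a batch of 6 into chunks of 2 runs
# `func` three times and concatenates the per-chunk outputs along `out_dims`,
# matching the unchunked result. `_demo_chunked_vmap` is illustrative only
# and not part of this module.
def _demo_chunked_vmap():
    import torch
    from torch.func import vmap

    x = torch.randn(6, 3)
    full = vmap(torch.sin)(x)                    # one batched call over all 6 rows
    chunked = vmap(torch.sin, chunk_size=2)(x)   # three calls on chunks of 2 rows
    assert torch.allclose(full, chunked)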
def _create_batched_inputs(
        in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable) -> Tuple[Tuple, int]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'expected `in_dims` to be int or a (potentially nested) tuple '
            f'matching the structure of inputs, got: {type(in_dims)}.')
    if len(args) == 0:
        raise ValueError(
            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')

    for arg, in_dim in zip(flat_args, flat_in_dims):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but the input is of type '
                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
                f'please use None as the respective in_dim')
        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
                f'of dimensionality {arg.dim()} so expected in_dim to satisfy '
                f'0 <= in_dim < {arg.dim()}.')

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level)
                      for in_dim, arg in zip(flat_in_dims, flat_args)]
    return tree_unflatten(batched_inputs, args_spec), batch_size
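# Hedged sketch of the validation above, seen through the public API: an
# integer in_dim that points at a non-Tensor argument is rejected with a
# ValueError, and the fix is to pass None for that position so the value is
# passed through unbatched. `_demo_in_dims_validation` is illustrative only.
def _demo_in_dims_validation():
    import torch
    from torch.func import vmap

    x = torch.randn(4, 3)
    try:
        vmap(torch.add, in_dims=(0, 0))(x, 1.0)   # in_dim=0 for a float: rejected
    except ValueError:
        pass
    out = vmap(torch.add, in_dims=(0, None))(x, 1.0)  # None: broadcast the scalar
    assert out.shape == (4, 3)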
def _process_batched_inputs(
        in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'expected `in_dims` to be int or a (potentially nested) tuple '
            f'matching the structure of inputs, got: {type(in_dims)}.')
    if len(args) == 0:
        raise ValueError(
            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')

    for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but the input is of type '
                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
                f'please use None as the respective in_dim')
        if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
                f'of dimensionality {arg.dim()} so expected in_dim to satisfy '
                f'-{arg.dim()} <= in_dim < {arg.dim()}.')
        if in_dim is not None and in_dim < 0:
            flat_in_dims[i] = in_dim % arg.dim()

    return _validate_and_get_batch_size(flat_in_dims, flat_args), flat_in_dims, flat_args, args_spec
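# Sketch of the negative-dim normalization above (`in_dim % arg.dim()`),
# observed via the public API: in_dims=-1 on a 2-D input is equivalent to
# in_dims=1, i.e. mapping over the last dimension. `_demo_negative_in_dims`
# is illustrative only.
def _demo_negative_in_dims():
    import torch
    from torch.func import vmap

    x = torch.randn(3, 5)
    out = vmap(torch.sum, in_dims=-1)(x)   # same as in_dims=1: map over columns
    assert out.shape == (5,)
    assert torch.allclose(out, x.sum(dim=0))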
def test_broadcast_to_and_flatten(self):
    cases = [
        (1, (), []),

        # Same (flat) structures
        ((1,), (0,), [1]),
        ([1], [0], [1]),
        ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
        ({'a': 1, 'b': 2}, {'a': 0, 'b': 0}, [1, 2]),

        # Mismatched (flat) structures
        ([1], (0,), None),
        ((1,), [0], None),
        ((1, 2, 3), (0, 0), None),
        ({'a': 1, 'b': 2}, {'a': 0}, None),
        ({'a': 1, 'b': 2}, {'a': 0, 'c': 0}, None),
        ({'a': 1, 'b': 2}, {'a': 0, 'b': 0, 'c': 0}, None),

        # Same (nested) structures
        ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
        ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),

        # Mismatched (nested) structures
        ((1, [2, 3]), (0, (0, 0)), None),
        ((1, [2, 3]), (0, [0, 0, 0]), None),

        # Broadcasting single value
        (1, (0, 0, 0), [1, 1, 1]),
        (1, [0, 0, 0], [1, 1, 1]),
        (1, {'a': 0, 'b': 0}, [1, 1]),
        (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
        (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),

        # Broadcast multiple things
        ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
        ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
        (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
    ]
    for pytree, to_pytree, expected in cases:
        _, to_spec = tree_flatten(to_pytree)
        result = _broadcast_to_and_flatten(pytree, to_spec)
        self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
def _unwrap_batched(
        batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
        out_dims: out_dims_t,
        vmap_level: int, batch_size: int, func: Callable) -> Tuple:
    flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

    for out in flat_batched_outputs:
        if isinstance(out, torch.Tensor):
            continue
        raise ValueError(
            f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
            f'Tensors, got type {type(out)} as a return.')

    def incompatible_error():
        raise ValueError(
            f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
            f'out_dims is not compatible with the structure of `outputs`. '
            f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs '
            f'has structure {output_spec}.')

    if isinstance(batched_outputs, torch.Tensor):
        # Some weird edge case requires us to spell out the following
        # see test_out_dims_edge_case
        if isinstance(out_dims, int):
            flat_out_dims = [out_dims]
        elif isinstance(out_dims, tuple) and len(out_dims) == 1:
            flat_out_dims = out_dims
            out_dims = out_dims[0]
        else:
            incompatible_error()
    else:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
        if flat_out_dims is None:
            incompatible_error()

    flat_outputs = [
        _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
        for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
    ]
    return tree_unflatten(flat_outputs, output_spec)
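# Sketch of what the unwrapping above controls, via the public API: after the
# batched call, `out_dims` chooses where the batch dimension lands in each
# output tensor. `_demo_out_dims` is illustrative only.
def _demo_out_dims():
    import torch
    from torch.func import vmap

    x = torch.randn(4, 3)
    y0 = vmap(lambda v: v * 2, out_dims=0)(x)  # batch dim stays at the front
    y1 = vmap(lambda v: v * 2, out_dims=1)(x)  # batch dim moved to position 1
    assert y0.shape == (4, 3) and y1.shape == (3, 4)
    assert torch.equal(y0.transpose(0, 1), y1)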