def get_exhaustive_batched_inputs_batch_norm_is_training(arg_values, kwarg_values, batch_size=2):
    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
    num_tensors = sum(is_tensors)
    if num_tensors == 1:
        # If there's only one input, we can't batch it, since running_mean/var
        # would be seen as unbatched tensors.
        return
    bdim_choices = get_bdim_choices_batch_norm(num_tensors, *arg_values)

    @memoize
    def get_batched_arg(arg, bdim):
        assert isinstance(arg, torch.Tensor)
        assert bdim is not None
        result, _ = add_batch_dim(arg, bdim, batch_size)
        return result

    for bdim_choice in bdim_choices:
        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)
        flat_batched_args = tuple(
            arg if in_dim is None else get_batched_arg(arg, in_dim)
            for arg, in_dim in zip(flat_args, flat_in_dims))
        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
        yield batched_args, in_dims, kwarg_values

def generate_subclass_choices_args_kwargs(args, kwargs):
    flat_kwargs, spec = tree_flatten(kwargs)
    flat_args_kwargs = list(args) + list(flat_kwargs)
    for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs):
        new_args = choice[:len(args)]
        new_kwargs = tree_unflatten(choice[len(args):], spec)
        which_args_are_wrapped = debug_metadata[:len(args)]
        which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec)
        yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped

def generate_subclass_choices_args_kwargs(args, kwargs, CCT):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct
    flat_kwargs, spec = tree_flatten(kwargs)
    flat_args_kwargs = list(args) + list(flat_kwargs)
    for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT):
        new_args = choice[:len(args)]
        new_kwargs = tree_unflatten(choice[len(args):], spec)
        which_args_are_wrapped = debug_metadata[:len(args)]
        which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec)
        yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped

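# Illustrative sketch (not from the source): the flatten/split/round-trip pattern used
# by generate_subclass_choices_args_kwargs above, shown standalone. Only
# torch.utils._pytree is assumed; the `x * 10` transform and the helper name are
# hypothetical stand-ins for whatever generate_subclass_choices would yield per leaf.
def _demo_split_args_kwargs():
    from torch.utils._pytree import tree_flatten, tree_unflatten

    args = (1, 2)
    kwargs = {"alpha": 3, "beta": {"gamma": 4}}
    flat_kwargs, spec = tree_flatten(kwargs)
    flat = list(args) + list(flat_kwargs)        # [1, 2, 3, 4]
    choice = [x * 10 for x in flat]              # stand-in for one generated "choice"
    new_args = choice[:len(args)]                # positional part: [10, 20]
    new_kwargs = tree_unflatten(choice[len(args):], spec)
    assert new_kwargs == {"alpha": 30, "beta": {"gamma": 40}}
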
def _create_batched_inputs(
        flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level)  # type: ignore
                      for in_dim, arg in zip(flat_in_dims, flat_args)]
    return tree_unflatten(batched_inputs, args_spec)

def wrapped_with_chunks(*args, **kwargs):
    _check_out_dims_is_int_or_int_pytree(out_dims, func)
    _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)

    # Chunk flat arguments
    chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)

    # Apply vmap on each chunk
    chunks_output = []
    rs = torch.get_rng_state() if randomness == "same" else None
    for flat_args in chunks_flat_args:
        batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
        if rs is not None:
            torch.set_rng_state(rs)
        chunks_output.append(
            _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec,
                       out_dims, randomness, **kwargs))

    flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)
    # Deleting temporaries early helps reduce memory usage on devices like CUDA
    del chunks_output

    # Concatenate chunks along out_dim
    flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
    assert len(flat_out_dims) == len(flat_output_chunks)
    flat_output = []
    for out_dim in flat_out_dims:
        flat_output.append(torch.cat(flat_output_chunks[0], dim=out_dim))
        # Release source data as we go
        del flat_output_chunks[0]
    del flat_output_chunks

    # Finally, unflatten the output
    return tree_unflatten(flat_output, arg_spec)

def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)

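# Illustrative sketch (not from the source): how _autograd_grad above behaves for a
# mixed pytree of differentiable and non-differentiable inputs. Assumes torch and the
# _autograd_grad defined above are in scope; the demo helper name is hypothetical.
def _demo_autograd_grad_structure():
    x = torch.randn(3, requires_grad=True)
    y = torch.randn(3)                      # does not require grad
    out = (x * 2).sum()

    # Gradients come back in the same pytree structure as `inputs`; inputs that
    # don't require grad get zeros_like placeholders instead of None.
    gx, gy = _autograd_grad([out], (x, y))
    assert torch.equal(gx, torch.full_like(x, 2.0))
    assert torch.equal(gy, torch.zeros_like(y))
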
def run_test_with_leaf(leaf):
    values, treespec = tree_flatten(leaf)
    self.assertEqual(values, [leaf])
    self.assertEqual(treespec, LeafSpec())

    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, leaf)

def unflatten_outs(self, out):
    if self._pytree_info is None:
        return out
    if not isinstance(out, list):
        out = [out]
    assert self._pytree_info.out_spec is not None
    return pytree.tree_unflatten(out, self._pytree_info.out_spec)

def flat_fn(*flat_tensor_args):
    # The inputs are flattened tensor args. Prepare the args in the
    # order that the original function expects. Add static args as well.
    # They will appear as tensor constants in the traced graph.
    nonlocal out_spec, static_args
    tensor_args, kwargs = pytree.tree_unflatten(flat_tensor_args, tensor_args_spec)
    if static_argnums is None:
        args = tensor_args
    else:
        args = rearrange(tensor_args, static_args, static_argnums)
    tree_out = fn(*args, **kwargs)
    flat_out, spec = pytree.tree_flatten(tree_out)
    for i in flat_out:
        is_known_type = False
        for j in KNOWN_TYPES:
            if isinstance(i, j):
                is_known_type = True
                break
        if not is_known_type:
            raise RuntimeError(
                f"Found {type(i)} in output, which is not a known type. "
                "If this type holds tensors, you need to register a pytree for it. "
                "See https://github.com/pytorch/functorch/issues/475 for a brief "
                "explanation why. If you don't need to register a pytree, please "
                "leave a comment explaining your use case and we'll make this more "
                "ergonomic to deal with")
    out_spec.set(spec)
    return flat_out

def flatten_fn(*args):
    tree_args = pytree.tree_unflatten(list(args), in_spec)
    tree_out = root_fn(*tree_args)
    out_args, out_spec = pytree.tree_flatten(tree_out)
    assert isinstance(self.graph._codegen, _PyTreeCodeGen)
    self.graph._codegen.pytree_info = self.graph._codegen.pytree_info._replace(out_spec=out_spec)
    return out_args

def flatten_fn(*args):
    tree_args = pytree.tree_unflatten(list(args), in_spec)
    tree_out = root_fn(*tree_args)
    out_args, out_spec = pytree.tree_flatten(tree_out)
    assert self.graph._pytree_info is not None
    self.graph._pytree_info = self.graph._pytree_info._replace(out_spec=out_spec)
    return out_args

def wrapped(*args):
    phs = pytree.tree_map(lambda _: fx.PH, args)  # type: ignore[attr-defined]
    fx_tracer = PythonKeyTracer()
    fake_tensor_mode: Any = nullcontext()
    if tracing_mode == "real":
        fake_tensor_mode = nullcontext()
    elif tracing_mode == "fake":
        fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
    elif tracing_mode == "symbolic":
        fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
    else:
        raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

    proxy_mode = ProxyTorchDispatchMode(fx_tracer, trace_factory_functions=trace_factory_functions)

    def wrap_fake_concrete(x):
        if isinstance(x, torch.Tensor):
            return fake_tensor_mode.from_tensor(x)  # type: ignore[attr-defined]
        return x

    shape_env = ShapeEnv()

    # todo: Figure out a more informative name for symints
    def wrap_fake_symbolic(x, sym_shape):
        if isinstance(x, torch.Tensor):
            val = FakeTensor(fake_tensor_mode, torch.empty(sym_shape, device="meta"), x.device)
            return val
        return x

    wrap_fn_map = {
        "real": lambda x: x,
        "fake": wrap_fake_concrete,
    }
    if tracing_mode == "symbolic":
        flat_shapes = shape_env.create_shapes_for_args(args)
        flat_args, spec = pytree.tree_flatten(args)
        args = pytree.tree_unflatten(
            list(map(lambda a: wrap_fake_symbolic(a[0], a[1]), zip(flat_args, flat_shapes))), spec)
    else:
        args = pytree.tree_map(wrap_fn_map[tracing_mode], args)

    with decompose(decomposition_table), fake_tensor_mode, proxy_mode:  # type: ignore[attr-defined]
        t = dispatch_trace(wrap_key(f, args, proxy_mode), tracer=fx_tracer, concrete_args=tuple(phs))

    # TODO: kind of a bad way to do it, should maybe figure out a better way
    t.shape_env = shape_env  # type: ignore[assignment]
    return t

def wrapped(*args):
    flat_args, args_spec = pytree.tree_flatten(args)
    assert len(flat_args) == len(flat_inps)
    for idx, arg in enumerate(flat_args):
        if isinstance(flat_inps[idx], torch.Tensor):
            with no_dispatch():
                flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=flat_inps[idx].is_leaf)
        else:
            flat_args[idx] = flat_inps[idx]
    tree_args = pytree.tree_unflatten(flat_args, args_spec)
    out = f(*tree_args)
    flat_outs, out_spec = pytree.tree_flatten(out)
    for idx in range(len(flat_outs)):
        if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], ProxyTensor):
            flat_outs[idx] = flat_outs[idx].proxy
    return pytree.tree_unflatten(flat_outs, out_spec)

def wrapped(*args):
    flat_args, args_spec = pytree.tree_flatten(args)
    assert len(flat_args) == len(flat_inps)
    for idx, arg in enumerate(flat_args):
        if isinstance(flat_inps[idx], torch.Tensor):
            flat_args[idx] = PythonTensor(flat_inps[idx], arg)
        else:
            flat_args[idx] = flat_inps[idx]
    tree_args = pytree.tree_unflatten(flat_args, args_spec)
    out = f(*tree_args)
    flat_outs, out_spec = pytree.tree_flatten(out)
    for idx in range(len(flat_outs)):
        if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], PythonTensor):
            flat_outs[idx] = flat_outs[idx].proxy
    return pytree.tree_unflatten(flat_outs, out_spec)

def test_flatten_unflatten_torch_namedtuple_return_type(self):
    x = torch.randn(3, 3)
    expected = torch.max(x, dim=0)

    values, spec = tree_flatten(expected)
    result = tree_unflatten(values, spec)

    self.assertEqual(type(result), type(expected))
    self.assertEqual(result, expected)

def run_test(pytree):
    values, treespec = tree_flatten(pytree)
    self.assertTrue(isinstance(values, list))
    self.assertEqual(len(values), treespec.num_leaves)

    # NB: python basic data structures (dict list tuple) all have
    # contents equality defined on them, so the following works for them.
    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, pytree)

def run_test(tup):
    expected_spec = TreeSpec(tuple, None, [LeafSpec() for _ in tup])
    values, treespec = tree_flatten(tup)
    self.assertTrue(isinstance(values, list))
    self.assertEqual(values, list(tup))
    self.assertEqual(treespec, expected_spec)

    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, tup)
    self.assertTrue(isinstance(unflattened, tuple))

def run_test(lst):
    expected_spec = TreeSpec(list, None, [LeafSpec() for _ in lst])
    values, treespec = tree_flatten(lst)
    self.assertTrue(isinstance(values, list))
    self.assertEqual(values, lst)
    self.assertEqual(treespec, expected_spec)

    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, lst)
    self.assertTrue(isinstance(unflattened, list))

def functional_call(*args, **kwargs):
    with _stateless.reparametrize_module(
            mod, pytree.tree_unflatten(args[:params_len], params_spec)):
        out = mod(*args[params_len:], **kwargs)
    if not isinstance(out, (tuple, list)):
        raise RuntimeError(
            "Graph output must be a tuple(). This is so that we can avoid "
            "pytree processing of the outputs. Please change the module to "
            "have tuple outputs or use aot_module instead.")
    return out

def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None):
    CCT = generate_cct(enable_recursive_torch_dispatch=True, autograd_view_consistency=False)

    # Permutations of args and kwargs in CCT.
    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice

        def maybe_tangent(t):
            assert type(t) is not CCT
            # Generate a `tangent` tensor if the given object is a Tensor
            # and requires_grad is set.
            if isinstance(t, torch.Tensor) and t.requires_grad:
                return torch.randn_like(t)
            elif is_tensorlist(t):
                return list(torch.randn_like(e) if e.requires_grad else None for e in t)
            return None

        tangent_args = tuple(maybe_tangent(arg) for arg in args)
        flat_kwargs, spec = tree_flatten(kwargs)
        flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
        tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)

        # Permutations of tangent args and tangent kwargs in CCT.
        for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT):
            new_tang_args, new_tang_kwargs, \
                which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice
            with fwAD.dual_level():
                def maybe_make_dual(dual):
                    # Returns a dual tensor if the primal is a tensor/tensor subclass
                    # with requires_grad set.
                    primal, tangent = dual
                    if isinstance(primal, torch.Tensor) and primal.requires_grad:
                        return fwAD.make_dual(primal, tangent)
                    elif is_tensorlist(primal):
                        return tuple(fwAD.make_dual(pri, tang) if tang is not None else pri
                                     for pri, tang in zip(primal, tangent))
                    return primal

                op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args)))
                op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()}

                try:
                    if gradcheck_wrapper is None:
                        op(*op_args, **op_kwargs)
                    else:
                        gradcheck_wrapper(op, *op_args, **op_kwargs)
                # see NOTE: [What errors are Composite Compliance trying to catch?]
                except RuntimeError as err:
                    raise_composite_compliance_error(
                        err,
                        f"- wrapped_args: {which_args_are_wrapped}\n"
                        f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
                        f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n"
                        f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n"
                    )

def run_test(tup):
    expected_spec = TreeSpec(dict, list(tup.keys()), [LeafSpec() for _ in tup.values()])
    values, treespec = tree_flatten(tup)
    self.assertTrue(isinstance(values, list))
    self.assertEqual(values, list(tup.values()))
    self.assertEqual(treespec, expected_spec)

    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, tup)
    self.assertTrue(isinstance(unflattened, dict))

def run_test(odict):
    expected_spec = TreeSpec(OrderedDict, list(odict.keys()), [LeafSpec() for _ in odict.values()])
    values, treespec = tree_flatten(odict)
    self.assertTrue(isinstance(values, list))
    self.assertEqual(values, list(odict.values()))
    self.assertEqual(treespec, expected_spec)

    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, odict)
    self.assertTrue(isinstance(unflattened, OrderedDict))

def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2):
    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
    bdim_choices = get_bdim_choices(sum(is_tensors))

    @memoize
    def get_batched_arg(arg, bdim):
        assert isinstance(arg, torch.Tensor)
        assert bdim is not None
        result, _ = add_batch_dim(arg, bdim, batch_size)
        return result

    for bdim_choice in bdim_choices:
        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)
        flat_batched_args = tuple(
            arg if in_dim is None else get_batched_arg(arg, in_dim)
            for arg, in_dim in zip(flat_args, flat_in_dims))
        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
        yield batched_args, in_dims, kwarg_values

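# Hedged sketch (assumptions, not the source definitions): plausible shapes for two of
# the helpers get_exhaustive_batched_inputs relies on. `memoize` caches results by
# positional args, and `construct_in_dims` expands a per-tensor bdim choice so that
# non-tensor leaves get None; the real test utilities may differ in details.
def memoize(fn):
    memo = {}

    def wrapped(*args):
        if args not in memo:
            memo[args] = fn(*args)
        return memo[args]
    return wrapped


def construct_in_dims(bdim_choice_for_tensors, is_tensors):
    result = []
    bdim = iter(bdim_choice_for_tensors)
    for is_tensor in is_tensors:
        # Non-tensor leaves never get a batch dim.
        result.append(next(bdim) if is_tensor else None)
    return tuple(result)
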
def test_flatten_unflatten_return_type(self, op):
    x = torch.randn(3, 3)
    expected = op(x, dim=0)

    values, spec = tree_flatten(expected)
    # Check that values is actually List[Tensor] and not (ReturnType(...),)
    for value in values:
        self.assertTrue(isinstance(value, torch.Tensor))

    result = tree_unflatten(values, spec)
    self.assertEqual(type(result), type(expected))
    self.assertEqual(result, expected)

def wrapper(*cotangents, retain_graph=True, create_graph=True):
    flat_cotangents, cotangents_spec = tree_flatten(cotangents)
    if primals_out_spec != cotangents_spec:
        raise RuntimeError(
            f'Expected pytree structure of cotangents to be the same '
            f'as pytree structure of outputs to the function. '
            f'cotangents: {treespec_pprint(cotangents_spec)}, '
            f'primal output: {treespec_pprint(primals_out_spec)}')
    result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
                            retain_graph=retain_graph, create_graph=create_graph)
    return tree_unflatten(result, primals_spec)

def _create_batched_inputs(
        in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable) -> Tuple[Tuple, int]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'expected `in_dims` to be int or a (potentially nested) tuple '
            f'matching the structure of inputs, got: {type(in_dims)}.')
    if len(args) == 0:
        raise ValueError(
            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')

    for arg, in_dim in zip(flat_args, flat_in_dims):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but the input is of type '
                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
                f'please use None as the respective in_dim')
        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
                f'of dimensionality {arg.dim()} so expected in_dim to satisfy '
                f'0 <= in_dim < {arg.dim()}.')

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level)
                      for in_dim, arg in zip(flat_in_dims, flat_args)]
    return tree_unflatten(batched_inputs, args_spec), batch_size

def wrapped(*primals):
    _args = list(flat_args)
    for num, arg in zip(diff_argnums, primals):
        _args[num] = arg
    _args = tree_unflatten(_args, args_spec)
    result = f(*_args, **kwargs)
    if output_process_fn_grad is not None:
        result = output_process_fn_grad(result)
    if isinstance(result, tuple):
        # TODO: Remove the following hack for namedtuples
        result = tuple(result)
        result = tuple(r for r in result
                       if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex()))
        assert len(result) > 0
    return result

def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2,
          *batched_args, **kwarg_values):
    outs = []
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
    flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
    assert args_spec == dims_spec1
    assert args_spec == dims_spec2
    assert len(flat_dims1) == len(flat_dims2)
    for idx1 in range(batch_size1):
        out_split = []
        arg_split = [a.select(in_dim1, idx1) if in_dim1 is not None else a
                     for a, in_dim1 in zip(flat_args, flat_dims1)]
        for idx2 in range(batch_size2):
            new_args = [a.select(in_dim, idx2) if in_dim is not None else a
                        for a, in_dim in zip(arg_split, flat_dims2)]
            out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values)
            out_split.append(out)
        outs.append(out_split)

    loop_out = []
    for out_split in outs:
        if isinstance(out_split[0], torch.Tensor):
            loop_out.append(torch.stack(out_split, out_dim1))
        else:
            new_out = []
            for idx in range(len(out_split[0])):
                new_out.append(torch.stack([i[idx] for i in out_split], out_dim1))
            loop_out.append(new_out)

    new_out = []
    if isinstance(loop_out, torch.Tensor):
        new_out = torch.stack(loop_out, out_dim2)
    else:
        for idx in range(len(loop_out[0])):
            new_out.append(torch.stack([i[idx] for i in loop_out], out_dim2))
    return new_out

def wrapper(*args):
    level = _grad_increment_nesting()
    output, aux, grad_input = None, None, None
    try:
        args = _wrap_all_tensors(args, level)
        diff_args = _slice_argnums(args, argnums)
        tree_map_(partial(_create_differentiable, level=level), diff_args)

        output = f(*args)
        if has_aux:
            output, aux = output

        if not isinstance(output, torch.Tensor):
            raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                               f'to return a Tensor, got {type(output)}')
        if output.dim() != 0:
            raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                               'to return a scalar Tensor, got tensor with '
                               f'{output.dim()} dims. Maybe you wanted to '
                               'use the vjp or jacrev APIs instead?')

        flat_diff_args, spec = tree_flatten(diff_args)

        # NB: need create_graph so that backward pass isn't run in no_grad mode
        flat_outputs = _as_tuple(output)
        flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
        grad_input = tree_unflatten(flat_grad_input, spec)
    finally:
        if grad_input is not None:
            grad_input = _undo_create_differentiable(grad_input, level)
        if output is not None:
            output = _undo_create_differentiable(output, level)
        if aux is not None:
            aux = _undo_create_differentiable(aux, level)
        _grad_decrement_nesting()
    if has_aux:
        return grad_input, (output, aux)
    return grad_input, output

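# Illustrative sketch (not from the source): the wrapper above backs a
# grad_and_value-style transform. With functorch installed, the public API can be
# exercised roughly like this; the demo helper name is hypothetical.
def _demo_grad_and_value():
    from functorch import grad_and_value

    grads, value = grad_and_value(lambda x: (x ** 2).sum())(torch.tensor([1.0, 2.0]))
    # grads == tensor([2., 4.]); value == tensor(5.)
    return grads, value
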
def functional_call(*args, **kwargs):
    with stateless._reparametrize_module(
        mod, pytree.tree_unflatten(args[:params_len], params_spec)
    ):
        if isinstance(mod, torch.fx.GraphModule):
            with fx_traceback.override_stack_trace(), torch.autograd.detect_anomaly(
                check_nan=False
            ):
                out = Interpreter(mod).run(*args[params_len:], **kwargs)
        else:
            out = mod(*args[params_len:], **kwargs)

    if not isinstance(out, (tuple, list)):
        raise RuntimeError(
            "Graph output must be a tuple(). This is so that we can avoid "
            "pytree processing of the outputs. Please change the module to "
            "have tuple outputs or use aot_module instead."
        )
    return out