def wrapper(*args, **kwargs):
    """Call ``fn`` and, if any input is being traced, splice ``symbolic_fn`` into the trace.

    ``fn``, ``symbolic_fn`` and ``might_trace`` are not defined in this block;
    they are resolved from an enclosing scope.

    Returns:
        Always the result of ``fn(*args, **kwargs)``. When tracing is active,
        each output variable additionally has its value trace set to the
        matching JIT value produced by ``symbolic_fn``.
    """
    import torch
    # NOTE(review): ``torch.jit`` and ``Function`` are not referenced below;
    # presumably kept for import side effects -- confirm before removing.
    import torch.jit
    from torch.autograd import Function, function

    # Cheap caller-supplied pre-check: bail out before flattening anything.
    if not might_trace(args):
        return fn(*args, **kwargs)

    inputs = tuple(function._iter_variables(args))
    tracing = any(torch._C._jit_is_tracing(v) for v in inputs)
    if not tracing:
        return fn(*args, **kwargs)

    state = torch._C._get_tracing_state(inputs)

    traced_inputs = []
    for v in inputs:
        traced_inputs.append(torch._C._get_value_trace(state, v))

    # Ordering matters: fn may modify its inputs in place, so every input's
    # value trace has to be captured above before fn is allowed to run.
    outputs = fn(*args, **kwargs)

    graph_inputs = function._unflatten(traced_inputs, args)
    graph_outputs = symbolic_fn(state.graph(), *graph_inputs, **kwargs)

    # Pair each real output with the JIT value standing in for it. zip stops
    # at the shorter sequence, so a count mismatch is not reported here.
    paired = zip(function._iter_variables(outputs),
                 function._iter_jit_values(graph_outputs))
    for out_var, jit_val in paired:
        jit_val.inferTypeFrom(out_var.data)
        torch._C._set_value_trace(state, out_var, jit_val)

    return outputs
def wrapper(*args, **kwargs):
    """Call ``fn`` and, if a tensor input is being traced, splice ``symbolic_fn`` into the trace.

    Non-tensor arguments are tolerated: they are passed through to
    ``symbolic_fn`` unchanged, while tensor arguments are replaced by their
    value traces. ``fn``, ``symbolic_fn`` and ``might_trace`` are resolved
    from an enclosing scope.

    Returns:
        Always the result of ``fn(*args, **kwargs)``.
    """
    import torch
    # NOTE(review): ``torch.jit`` and ``Function`` are not referenced below;
    # presumably kept for import side effects -- confirm before removing.
    import torch.jit
    from torch.autograd import Function, function

    # Cheap caller-supplied pre-check: bail out before flattening anything.
    if not might_trace(args):
        return fn(*args, **kwargs)

    def is_tensor(obj):
        return isinstance(obj, torch.Tensor)

    flat = tuple(function._iter_tensors_permissive(args))
    tensors = tuple(filter(is_tensor, flat))
    if not any(torch._C._jit_is_tracing(t) for t in tensors):
        return fn(*args, **kwargs)

    state = torch._C._get_tracing_state(tensors)

    # Swap every tensor for its value trace; leave everything else as-is.
    traced = []
    for item in flat:
        if is_tensor(item):
            item = torch._C._get_value_trace(state, item)
        traced.append(item)

    # Ordering matters: fn may modify its inputs in place, so the value
    # traces have to be captured above before fn is allowed to run.
    outputs = fn(*args, **kwargs)

    graph_inputs = function._unflatten(traced, args)
    graph_outputs = symbolic_fn(state.graph(), *graph_inputs, **kwargs)

    # Pair each output tensor with the JIT value standing in for it. zip stops
    # at the shorter sequence, so a count mismatch is not reported here.
    paired = zip(function._iter_tensors(outputs),
                 function._iter_jit_values(graph_outputs))
    for out_tensor, jit_val in paired:
        jit_val.inferTypeFrom(out_tensor.data)
        torch._C._set_value_trace(state, out_tensor, jit_val)

    return outputs
def wrapper(*args, **kwargs):
    """Call ``fn`` and, while a trace is active, record ``symbolic_fn`` in its place.

    ``fn`` itself runs with tracing disabled. The JIT values produced by
    ``symbolic_fn`` are then attached as the value traces of ``fn``'s output
    tensors. ``fn``, ``symbolic_fn``, ``torch`` and ``function`` are resolved
    from an enclosing/module scope.

    Returns:
        Always the result of ``fn(*args, **kwargs)``.
    """
    state = torch._C._get_tracing_state()
    if not state:
        # No active trace: behave exactly like fn.
        return fn(*args, **kwargs)

    flat = tuple(function._iter_tensors_permissive(args))
    traced = []
    for item in flat:
        traced.append(
            torch._C._get_value_trace(item)
            if isinstance(item, torch.Tensor) else item)

    # Ordering matters: fn may modify its inputs in place, so the value
    # traces have to be captured above before fn is allowed to run.
    # Tracing is switched off around fn so that operations inside it that
    # cannot be traced do not raise errors.
    with torch.jit._disable_tracing():
        outputs = fn(*args, **kwargs)

    graph_inputs = function._unflatten(traced, args)
    graph_outputs = symbolic_fn(state.graph(), *graph_inputs, **kwargs)

    # Pair each output tensor with the JIT value standing in for it. zip stops
    # at the shorter sequence, so a count mismatch is not reported here.
    paired = zip(function._iter_tensors(outputs),
                 function._iter_jit_values(graph_outputs))
    for out_tensor, jit_val in paired:
        jit_val.inferTypeFrom(out_tensor.data)
        torch._C._set_value_trace(out_tensor, jit_val)

    return outputs
def symbolic(g, *flat_args):
    """Re-nest ``flat_args`` like ``args``, then emit ``symbolic_fn`` into ``g``.

    ``args``, ``kwargs`` and ``symbolic_fn`` are free variables taken from the
    enclosing scope. ``g`` is presumably the JIT graph under construction;
    confirm with the caller.

    Returns:
        A tuple of the JIT values found in ``symbolic_fn``'s result.
    """
    nested = function._unflatten(flat_args, args)
    produced = symbolic_fn(g, *nested, **kwargs)
    jit_values = function._iter_jit_values(produced)
    return tuple(jit_values)