def wrapper(*args, **kwargs):
    import torch
    import torch.jit
    from torch.autograd import Function, function

    # fast pass
    if not might_trace(args):
        return fn(*args, **kwargs)

    flat_args = tuple(function._iter_variables(args))
    if not any(map(torch._C._jit_is_tracing, flat_args)):
        return fn(*args, **kwargs)
    tstate = torch._C._get_tracing_state(flat_args)

    arg_values = [torch._C._get_value_trace(tstate, x) for x in flat_args]

    # This must come after the calls to get_value_trace, lest we
    # lose information due to in-place operations.
    output_vars = fn(*args, **kwargs)

    symbolic_args = function._unflatten(arg_values, args)
    output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)

    for var, val in zip(
            function._iter_variables(output_vars),
            function._iter_jit_values(output_vals)):
        val.inferTypeFrom(var.data)
        torch._C._set_value_trace(tstate, var, val)

    return output_vars

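# The wrappers in this file lean on two helpers from torch.autograd.function:
# an iterator that yields the tensor leaves of an arbitrarily nested args
# structure, and an _unflatten that rebuilds the structure from a flat
# sequence. A minimal pure-Python sketch of that flatten/unflatten contract
# (iter_leaves/unflatten are illustrative stand-ins, not the torch internals):

def iter_leaves(obj, is_leaf):
    # Depth-first walk yielding every leaf that matches the predicate.
    if is_leaf(obj):
        yield obj
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from iter_leaves(item, is_leaf)

def unflatten(flat, proto, is_leaf):
    # Rebuild a structure shaped like `proto`, consuming leaves from `flat`.
    it = iter(flat)

    def rebuild(p):
        if is_leaf(p):
            return next(it)
        if isinstance(p, (list, tuple)):
            return type(p)(rebuild(item) for item in p)
        return p

    return rebuild(proto)

# Round-trip example, with floats playing the role of tensors.
nested = (1.0, [2.0, (3.0, "tag")])
is_float = lambda x: isinstance(x, float)
flat = tuple(iter_leaves(nested, is_float))              # (1.0, 2.0, 3.0)
assert unflatten([x * 10 for x in flat], nested, is_float) == \
    (10.0, [20.0, (30.0, "tag")])
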
def wrapper(*args, **kwargs):
    output = fn(*args, **kwargs)

    # fast pass
    if first_arg_only and not torch._C._jit_is_tracing(args[0]):
        return output

    flat_args = tuple(function._iter_variables(args))
    if not any(map(torch._C._jit_is_tracing, flat_args)):
        return output

    flat_output_tensors = tuple(
        v.data for v in function._iter_variables(output))
    # TODO: kwargs aren't traced

    class ExportProxy(Function):
        @staticmethod
        def symbolic(g, *flat_args):
            symbolic_args = function._unflatten(flat_args, args)
            symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
            return tuple(function._iter_jit_values(symbolic_output))

        @staticmethod
        def forward(ctx, *unused_args):
            return flat_output_tensors

        @staticmethod
        def backward(ctx, *unused_args, **unused_kwargs):
            raise RuntimeError(
                "symbolic_override is meant for inference export only")

    flat_proxy_output = ExportProxy.apply(*flat_args)
    return function._unflatten(flat_proxy_output, output)

def wrapper(*args, **kwargs):
    output = fn(*args, **kwargs)

    flat_args = tuple(function._iter_variables(args))
    if not any(map(torch._C._jit_is_tracing, flat_args)):
        return output

    flat_output_tensors = tuple(
        v.data for v in function._iter_variables(output))
    assert len(list(function._iter_variables_permissive(
        list(kwargs.values())))) == 0, \
        "Passing Variable through kwargs is not supported"

    class ExportProxy(Function):
        @staticmethod
        def symbolic(g, *flat_args):
            symbolic_args = function._unflatten(flat_args, args)
            symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
            return tuple(function._iter_jit_values(symbolic_output))

        @staticmethod
        def forward(ctx, *unused_args):
            return flat_output_tensors

        @staticmethod
        def backward(ctx, *unused_args, **unused_kwargs):
            raise RuntimeError(
                "symbolic_override is meant for inference export only")

    flat_proxy_output = ExportProxy.apply(*flat_args)
    return function._unflatten(flat_proxy_output, output)

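# For context: each `wrapper` above is produced by a decorator factory that
# closes over the user's fn and a graph-building symbolic_fn. A torch-free,
# hedged sketch of that shape (symbolic_override, TRACE and TRACING here are
# illustrative stand-ins, not the actual torch.onnx machinery):

import functools

TRACE = []        # stand-in for the graph being built
TRACING = False   # stand-in for torch._C._jit_is_tracing

def symbolic_override(symbolic_fn):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            output = fn(*args, **kwargs)   # always compute the real result
            if not TRACING:                # the "fast pass" above
                return output
            # Record one opaque op in place of fn's internals.
            TRACE.append(symbolic_fn(TRACE, *args, **kwargs))
            return output
        return wrapper
    return decorator

@symbolic_override(lambda g, x, y: ("Add", x, y))
def add(x, y):
    return x + y

add(1, 2)                      # not tracing: nothing recorded
TRACING = True
add(1, 2)
assert TRACE == [("Add", 1, 2)]
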
def run(self, args, extra):
    # tracing is disabled, run the real thing, possibly timing it
    if not self.enabled:
        with _time("run_real", self.time):
            return self._run(*args)

    # tracing, but no trace exists, create one, possibly verifying it
    # by running it after creating it
    if self.saved_trace is None:
        _, out = self.record_trace(args, extra)
        self.proto = function._to_proto(out)
        return out

    trace_inputs = self.get_trace_inputs(args, extra)

    # just run the already created trace
    if not self.verify:
        return function._unflatten(self.run_trace(trace_inputs), self.proto)

    # verify an already created trace...
    cloned_inputs = tuple(_clone_inputs(trace_inputs))
    with _time("run_real", self.time), _fork_rng():
        out_real = self._run(*args)
    flat_trace_out = self.run_trace(cloned_inputs)
    _verify(flat_trace_out, flatten(out_real))
    return out_real

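# run() above follows a record-once / replay / optionally-verify policy. A
# minimal torch-free sketch of the same control flow (TracedRunner is a
# hypothetical stand-in; a cached callable plays the role of a compiled trace):

class TracedRunner:
    def __init__(self, fn, verify=False):
        self.fn = fn
        self.verify = verify
        self.saved_trace = None

    def __call__(self, *args):
        if self.saved_trace is None:
            # First call: "record" by capturing the callable.
            self.saved_trace = self.fn
            return self.fn(*args)
        out = self.saved_trace(*args)       # replay the recorded trace
        if self.verify:
            # Re-run the real thing and compare, mirroring _verify above.
            assert out == self.fn(*args), "trace diverged from real run"
        return out

runner = TracedRunner(lambda x: x * 2, verify=True)
assert runner(3) == 6   # records
assert runner(4) == 8   # replays and verifies
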
def wrapper(*args, **kwargs):
    import torch
    import torch.jit
    from torch.autograd import Function, function

    # fast pass
    if not might_trace(args):
        return fn(*args, **kwargs)

    flat_args = tuple(function._iter_tensors_permissive(args))
    flat_args_only_tensors = tuple(
        t for t in flat_args if isinstance(t, torch.Tensor))
    if not any(map(torch._C._jit_is_tracing, flat_args_only_tensors)):
        return fn(*args, **kwargs)
    tstate = torch._C._get_tracing_state(flat_args_only_tensors)

    arg_values = [torch._C._get_value_trace(tstate, x)
                  if isinstance(x, torch.Tensor) else x
                  for x in flat_args]

    # This must come after the calls to get_value_trace, lest we
    # lose information due to in-place operations.
    output_vars = fn(*args, **kwargs)

    symbolic_args = function._unflatten(arg_values, args)
    output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)

    for var, val in zip(
            function._iter_tensors(output_vars),
            function._iter_jit_values(output_vals)):
        val.inferTypeFrom(var.data)
        torch._C._set_value_trace(tstate, var, val)

    return output_vars

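# This variant flattens "permissively": non-tensor leaves (ints, strings,
# ...) survive flattening and flow into symbolic_fn unchanged, while only
# real tensors get their traced graph values substituted. A torch-free
# sketch of that substitution step (FakeTensor and all names are
# illustrative, not torch internals):

class FakeTensor:
    def __init__(self, name):
        self.name = name

def lookup_traced_values(flat_args, value_trace):
    # Tensors map to their recorded graph values; the rest pass through.
    return [value_trace[x.name] if isinstance(x, FakeTensor) else x
            for x in flat_args]

t = FakeTensor("t0")
flat_args = [t, 2, "sum"]                    # mixed tensor / plain leaves
value_trace = {"t0": "%0"}
assert lookup_traced_values(flat_args, value_trace) == ["%0", 2, "sum"]
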
def wrapper(*args, **kwargs):
    output = fn(*args, **kwargs)

    # fast pass
    if first_arg_only and not torch._C._jit_is_tracing(args[0]):
        return output

    flat_args = tuple(function._iter_variables(args))
    if not any(map(torch._C._jit_is_tracing, flat_args)):
        return output

    flat_output_tensors = tuple(
        v.data for v in function._iter_variables(output))
    assert len(list(function._iter_variables_permissive(
        list(kwargs.values())))) == 0, \
        "Passing Variable through kwargs is not supported"

    class ExportProxy(Function):
        @staticmethod
        def symbolic(g, *flat_args):
            symbolic_args = function._unflatten(flat_args, args)
            symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
            return tuple(function._iter_jit_values(symbolic_output))

        @staticmethod
        def forward(ctx, *unused_args):
            return flat_output_tensors

        @staticmethod
        def backward(ctx, *unused_args, **unused_kwargs):
            raise RuntimeError(
                "symbolic_override is meant for inference export only")

    flat_proxy_output = ExportProxy.apply(*flat_args)
    return function._unflatten(flat_proxy_output, output)

def wrapper(*args, **kwargs):
    tstate = torch._C._get_tracing_state()
    if not tstate:
        return fn(*args, **kwargs)

    flat_args = tuple(function._iter_tensors_permissive(args))
    arg_values = [torch._C._get_value_trace(x)
                  if isinstance(x, torch.Tensor) else x
                  for x in flat_args]

    # This must come after the calls to get_value_trace, lest we
    # lose information due to in-place operations.
    # Temporarily disable tracing so that we don't cause errors
    # for things inside of fn that may not be traceable.
    with torch.jit._disable_tracing():
        output_vars = fn(*args, **kwargs)

    symbolic_args = function._unflatten(arg_values, args)
    output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)

    for var, val in zip(
            function._iter_tensors(output_vars),
            function._iter_jit_values(output_vals)):
        val.inferTypeFrom(var.data)
        torch._C._set_value_trace(var, val)

    return output_vars

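# The comment above ("must come after the calls to get_value_trace") guards
# against in-place ops: once fn runs, an in-place update can rebind an
# input's entry in the value-trace map, so the value the op actually
# consumed must be read out first. A toy illustration with a dict standing
# in for the value-trace state (nothing here is a torch internal):

value_trace = {}

def traced_add_(buf):
    buf[0] += 1
    value_trace[id(buf)] = "%after"   # in-place op rebinds the traced value
    return buf

buf = [10]
value_trace[id(buf)] = "%before"      # value recorded when buf was produced

arg_value = value_trace[id(buf)]      # read BEFORE running the op
traced_add_(buf)
assert arg_value == "%before"              # the value the op consumed
assert value_trace[id(buf)] == "%after"    # what a late read would have seen
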
def run_closure(self, trace_info, args, trace_inputs):
    if self.verify:
        cloned_args = tuple(_clone_inputs(args))
        with _time("run_real", self.time), _fork_rng(self.verify):
            flat_real_out = flatten((self._run(*cloned_args),))
    with _time("run_trace", self.time):
        flat_out = trace_info.closure()(*_varify(trace_inputs))
    if not isinstance(flat_out, tuple):
        flat_out = (flat_out,)
    if self.verify:
        _verify(flat_out, flat_real_out)
    return function._unflatten(flat_out, trace_info.proto)

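# _verify above compares the flattened trace outputs against the flattened
# real outputs. A minimal sketch of such a check with plain floats (the real
# helper operates on tensors; the name, tolerances and error message here
# are assumptions):

def verify(flat_trace_out, flat_real_out, rtol=1e-5, atol=1e-8):
    assert len(flat_trace_out) == len(flat_real_out), "arity mismatch"
    for i, (t, r) in enumerate(zip(flat_trace_out, flat_real_out)):
        if abs(t - r) > atol + rtol * abs(r):
            raise RuntimeError(f"output {i} diverged: trace={t} real={r}")

verify((1.0, 2.0), (1.0, 2.0 + 1e-9))   # passes; a larger gap would raise
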
def hack_onnx_rnn(fargs, output, args, kwargs):
    input, all_weights, hx = fargs
    output_tensors = tuple(v.data for v in _iter_variables(output))
    flat_weights = tuple(_iter_variables(all_weights))
    flat_hx = tuple(_iter_variables(hx))

    class RNNSymbolic(Function):
        @staticmethod
        def symbolic(g, *fargs):
            # NOTE: fargs contains Variable inputs (input + weight + hidden)
            # NOTE: args/kwargs contain RNN parameters
            raise RuntimeError("hack_onnx_rnn NYI")

        @staticmethod
        def forward(ctx, *fargs):
            return output_tensors

        @staticmethod
        def backward(ctx, *gargs, **gkwargs):
            raise RuntimeError("FIXME: Traced RNNs don't support backward")

    flat_output = RNNSymbolic.apply(*((input,) + flat_weights + flat_hx))
    return _unflatten(flat_output, output)

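# RNNSymbolic and ExportProxy share one trick: an autograd Function whose
# forward ignores its inputs and returns tensors computed ahead of time
# (captured by closure), so the traced graph sees a single opaque node. A
# runnable sketch of just that proxy pattern, inference-only like the
# snippets above (make_inference_proxy is an illustrative name):

import torch
from torch.autograd import Function

def make_inference_proxy(precomputed):
    class InferenceProxy(Function):
        @staticmethod
        def forward(ctx, *unused_args):
            # Ignore the inputs; return results computed ahead of time.
            return precomputed

        @staticmethod
        def backward(ctx, *unused_grads):
            raise RuntimeError("inference export only, no backward")

    return InferenceProxy

x = torch.ones(3)
real_out = (x * 2).detach()          # the "real" computation, done eagerly
proxy = make_inference_proxy(real_out)
y = proxy.apply(x)                   # one opaque node from the graph's view
assert torch.equal(y, real_out)
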
def symbolic(g, *flat_args):
    symbolic_args = function._unflatten(flat_args, args)
    symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
    return tuple(function._iter_jit_values(symbolic_output))