Example #1
0
    def wrapper(*args, **kwargs):
        """Run ``fn``; while tracing, bind its outputs to ``symbolic_fn`` values.

        When no flattened Variable argument is being traced, this just
        forwards to ``fn``. Otherwise it looks up the traced value of every
        input, runs ``fn`` for the real outputs, calls ``symbolic_fn`` on the
        tracing graph with the traced inputs (regrouped like ``args``), and
        records each real output's trace as the matching symbolic value.
        Returns the real outputs of ``fn``.
        """
        import torch
        import torch.jit
        from torch.autograd import Function, function

        # Fast path: skip all flattening when might_trace rules tracing out.
        if not might_trace(args):
            return fn(*args, **kwargs)

        inputs = tuple(function._iter_variables(args))
        if not any(map(torch._C._jit_is_tracing, inputs)):
            return fn(*args, **kwargs)

        state = torch._C._get_tracing_state(inputs)

        # Snapshot the traced input values *before* fn runs: an in-place op
        # inside fn would otherwise lose this information.
        traced_inputs = [torch._C._get_value_trace(state, var) for var in inputs]

        real_outputs = fn(*args, **kwargs)

        graph_inputs = function._unflatten(traced_inputs, args)
        graph_outputs = symbolic_fn(state.graph(), *graph_inputs, **kwargs)

        pairs = zip(function._iter_variables(real_outputs),
                    function._iter_jit_values(graph_outputs))
        for real, sym in pairs:
            sym.inferTypeFrom(real.data)
            torch._C._set_value_trace(state, real, sym)

        return real_outputs
Example #2
0
    def wrapper(*args, **kwargs):
        """Run ``fn``; while tracing, re-emit its outputs via ``symbolic_fn``.

        ``fn`` always runs first for its real result. If no flattened Variable
        argument is being traced, that result is returned unchanged.
        Otherwise the flat output tensors are routed through a one-off
        ``ExportProxy`` autograd Function whose ``symbolic`` calls
        ``symbolic_fn``, and the proxy outputs are regrouped into the
        structure of ``output``. The proxy has no backward: it raises
        RuntimeError (inference export only).
        """
        output = fn(*args, **kwargs)
        # Fast path: when only the first argument decides tracing, check it
        # alone. ``args`` can be empty if that argument was passed by keyword;
        # indexing ``args[0]`` would then raise IndexError. kwargs are not
        # traced anyway (see TODO below), so treat that case as "not tracing".
        if first_arg_only and (not args or
                               not torch._C._jit_is_tracing(args[0])):
            return output

        flat_args = tuple(function._iter_variables(args))
        if not any(map(torch._C._jit_is_tracing, flat_args)):
            return output
        # Capture the real output tensors now; ExportProxy.forward returns them.
        flat_output_tensors = tuple(
            v.data for v in function._iter_variables(output))
        # TODO: kwargs aren't traced

        class ExportProxy(Function):
            @staticmethod
            def symbolic(g, *flat_args):
                # Regroup the flat graph inputs like ``args``, let symbolic_fn
                # build the graph, and return its outputs as a flat tuple.
                symbolic_args = function._unflatten(flat_args, args)
                symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
                return tuple(function._iter_jit_values(symbolic_output))

            @staticmethod
            def forward(ctx, *unused_args):
                # The real computation already ran in ``fn`` above.
                return flat_output_tensors

            @staticmethod
            def backward(ctx, *unused_args, **unused_kwargs):
                raise RuntimeError(
                    "symbolic_override is meant for inference export only")

        flat_proxy_output = ExportProxy.apply(*flat_args)
        return function._unflatten(flat_proxy_output, output)
Example #3
0
        def wrapper(*args, **kwargs):
            """Run ``fn``; while tracing, re-emit its outputs via ``symbolic_fn``.

            ``fn`` always runs first for its real result. If no flattened
            Variable argument is being traced, that result is returned
            unchanged. Otherwise the flat output tensors are routed through a
            one-off ``ExportProxy`` autograd Function whose ``symbolic`` calls
            ``symbolic_fn``, and the proxy outputs are regrouped into the
            structure of ``output``.

            Raises:
                AssertionError: while tracing, if any Variable is found in
                    ``kwargs`` (only positional ``args`` feed the proxy).
                RuntimeError: from the proxy's backward; it is export-only.
            """
            output = fn(*args, **kwargs)
            flat_args = tuple(function._iter_variables(args))
            if not any(map(torch._C._jit_is_tracing, flat_args)):
                return output
            # Capture the real output tensors now; ExportProxy.forward
            # returns them.
            flat_output_tensors = tuple(
                v.data for v in function._iter_variables(output))
            # Only ``flat_args`` are inputs to ExportProxy.apply, so Variables
            # hidden in kwargs would not be connected in the trace. Raise
            # explicitly instead of ``assert`` so the check is not stripped
            # under ``python -O``. AssertionError is kept for existing callers.
            kwarg_vars = function._iter_variables_permissive(
                list(kwargs.values()))
            if any(True for _ in kwarg_vars):
                raise AssertionError(
                    "Passing Variable through kwargs is not supported")

            class ExportProxy(Function):
                @staticmethod
                def symbolic(g, *flat_args):
                    # Regroup the flat graph inputs like ``args``, let
                    # symbolic_fn build the graph, and flatten its outputs.
                    symbolic_args = function._unflatten(flat_args, args)
                    symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
                    return tuple(function._iter_jit_values(symbolic_output))

                @staticmethod
                def forward(ctx, *unused_args):
                    # The real computation already ran in ``fn`` above.
                    return flat_output_tensors

                @staticmethod
                def backward(ctx, *unused_args, **unused_kwargs):
                    raise RuntimeError(
                        "symbolic_override is meant for inference export only")

            flat_proxy_output = ExportProxy.apply(*flat_args)
            return function._unflatten(flat_proxy_output, output)
Example #4
0
    def run(self, args, extra):
        """Execute the wrapped computation according to the tracing settings.

        * tracing disabled: run ``self._run(*args)`` (timed as "run_real");
        * tracing enabled, no saved trace yet: record one, remember the
          output structure in ``self.proto``, and return the recorded output;
        * saved trace, no verification: run the trace and regroup its flat
          output using ``self.proto``;
        * saved trace with ``self.verify``: run both the real computation and
          the trace (on cloned inputs), compare them with ``_verify``, and
          return the real output.
        """
        if not self.enabled:
            with _time("run_real", self.time):
                return self._run(*args)

        if self.saved_trace is None:
            # Per the original notes, record_trace may also verify the new
            # trace by running it after creating it.
            _trace, recorded_out = self.record_trace(args, extra)
            self.proto = function._to_proto(recorded_out)
            return recorded_out

        trace_inputs = self.get_trace_inputs(args, extra)

        if self.verify:
            # Clone first so the trace run sees inputs untouched by the real run.
            inputs_copy = tuple(_clone_inputs(trace_inputs))
            with _time("run_real", self.time), _fork_rng():
                real_out = self._run(*args)
            trace_out = self.run_trace(inputs_copy)
            _verify(trace_out, flatten(real_out))
            return real_out

        flat_out = self.run_trace(trace_inputs)
        return function._unflatten(flat_out, self.proto)
Example #5
0
    def wrapper(*args, **kwargs):
        """Run ``fn``; while tracing, bind its outputs to ``symbolic_fn`` values.

        Arguments are flattened permissively, so non-tensor leaves are kept.
        Only tensor leaves are checked for tracing and looked up in the trace.
        Non-tensor leaves are passed to ``symbolic_fn`` unchanged. Returns the
        real outputs of ``fn``.
        """
        import torch
        import torch.jit
        from torch.autograd import Function, function

        # Fast path: skip all flattening when might_trace rules tracing out.
        if not might_trace(args):
            return fn(*args, **kwargs)

        flat = tuple(function._iter_tensors_permissive(args))
        tensors = tuple(filter(lambda obj: isinstance(obj, torch.Tensor), flat))
        if not any(map(torch._C._jit_is_tracing, tensors)):
            return fn(*args, **kwargs)

        state = torch._C._get_tracing_state(tensors)

        # Look up traced values before fn runs; in-place ops inside fn would
        # otherwise lose that information. Non-tensors pass through as-is.
        traced = []
        for obj in flat:
            if isinstance(obj, torch.Tensor):
                traced.append(torch._C._get_value_trace(state, obj))
            else:
                traced.append(obj)

        real_outputs = fn(*args, **kwargs)

        graph_inputs = function._unflatten(traced, args)
        graph_outputs = symbolic_fn(state.graph(), *graph_inputs, **kwargs)

        pairs = zip(function._iter_tensors(real_outputs),
                    function._iter_jit_values(graph_outputs))
        for real, sym in pairs:
            sym.inferTypeFrom(real.data)
            torch._C._set_value_trace(state, real, sym)

        return real_outputs
Example #6
0
    def wrapper(*args, **kwargs):
        """Run ``fn``; while tracing, re-emit its outputs via ``symbolic_fn``.

        ``fn`` always runs first for its real result. If no flattened Variable
        argument is being traced, that result is returned unchanged.
        Otherwise the flat output tensors are routed through a one-off
        ``ExportProxy`` autograd Function whose ``symbolic`` calls
        ``symbolic_fn``, and the proxy outputs are regrouped into the
        structure of ``output``.

        Raises:
            AssertionError: while tracing, if any Variable is found in
                ``kwargs`` (only positional ``args`` feed the proxy).
            RuntimeError: from the proxy's backward; it is export-only.
        """
        output = fn(*args, **kwargs)
        # Fast path: when only the first argument decides tracing, check it
        # alone. ``args`` can be empty if that argument was passed by keyword;
        # indexing ``args[0]`` would then raise IndexError, so treat that case
        # as "not tracing" (kwargs are never fed to the proxy below).
        if first_arg_only and (not args or
                               not torch._C._jit_is_tracing(args[0])):
            return output

        flat_args = tuple(function._iter_variables(args))
        if not any(map(torch._C._jit_is_tracing, flat_args)):
            return output
        # Capture the real output tensors now; ExportProxy.forward returns them.
        flat_output_tensors = tuple(
            v.data for v in function._iter_variables(output))
        # Only ``flat_args`` are inputs to ExportProxy.apply, so Variables
        # hidden in kwargs would not be connected in the trace. Raise
        # explicitly instead of ``assert`` so the check is not stripped under
        # ``python -O``. AssertionError is kept for existing callers.
        kwarg_vars = function._iter_variables_permissive(list(kwargs.values()))
        if any(True for _ in kwarg_vars):
            raise AssertionError(
                "Passing Variable through kwargs is not supported")

        class ExportProxy(Function):
            @staticmethod
            def symbolic(g, *flat_args):
                # Regroup the flat graph inputs like ``args``, let symbolic_fn
                # build the graph, and return its outputs as a flat tuple.
                symbolic_args = function._unflatten(flat_args, args)
                symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
                return tuple(function._iter_jit_values(symbolic_output))

            @staticmethod
            def forward(ctx, *unused_args):
                # The real computation already ran in ``fn`` above.
                return flat_output_tensors

            @staticmethod
            def backward(ctx, *unused_args, **unused_kwargs):
                raise RuntimeError(
                    "symbolic_override is meant for inference export only")

        flat_proxy_output = ExportProxy.apply(*flat_args)
        return function._unflatten(flat_proxy_output, output)
Example #7
0
        def wrapper(*args, **kwargs):
            """Run ``fn``; under an active trace, bind outputs to ``symbolic_fn``.

            Without a tracing state this simply forwards to ``fn``. Otherwise
            the traced values of the tensor leaves of ``args`` are collected
            (non-tensor leaves pass through). ``fn`` is run with tracing
            disabled, ``symbolic_fn`` is called on the tracing graph, and each
            real output tensor's trace is set to the matching symbolic value.
            Returns the real outputs of ``fn``.
            """
            state = torch._C._get_tracing_state()
            if not state:
                return fn(*args, **kwargs)

            flat = tuple(function._iter_tensors_permissive(args))
            traced = [
                torch._C._get_value_trace(leaf)
                if isinstance(leaf, torch.Tensor) else leaf
                for leaf in flat
            ]

            # Traced values were captured above, before fn runs, so in-place
            # ops inside fn cannot lose that information. fn itself runs with
            # tracing disabled so untraceable code inside it does not error.
            with torch.jit._disable_tracing():
                real_outputs = fn(*args, **kwargs)

            graph_inputs = function._unflatten(traced, args)
            graph_outputs = symbolic_fn(state.graph(), *graph_inputs, **kwargs)

            pairs = zip(function._iter_tensors(real_outputs),
                        function._iter_jit_values(graph_outputs))
            for real, sym in pairs:
                sym.inferTypeFrom(real.data)
                torch._C._set_value_trace(real, sym)

            return real_outputs
Example #8
0
    def run_closure(self, trace_info, args, trace_inputs):
        """Run a recorded trace closure, optionally checking it against reality.

        With ``self.verify`` set, the real computation is first run on cloned
        ``args`` (timed as "run_real", inside ``_fork_rng``). The trace
        closure is then run on ``_varify(trace_inputs)`` (timed as
        "run_trace"), and the two flat outputs are compared with ``_verify``.
        A non-tuple closure result is wrapped in a 1-tuple. The flat result is
        regrouped using ``trace_info.proto``.
        """
        if self.verify:
            args_copy = tuple(_clone_inputs(args))
            with _time("run_real", self.time), _fork_rng(self.verify):
                expected_flat = flatten((self._run(*args_copy),))

        with _time("run_trace", self.time):
            closure = trace_info.closure()
            traced_out = closure(*_varify(trace_inputs))
        traced_out = traced_out if isinstance(traced_out, tuple) else (traced_out,)

        if self.verify:
            _verify(traced_out, expected_flat)

        return function._unflatten(traced_out, trace_info.proto)
Example #9
0
def hack_onnx_rnn(fargs, output, args, kwargs):
    """Route an RNN's precomputed outputs through a placeholder Function.

    ``fargs`` is unpacked as ``(input, all_weights, hx)``. The proxy's inputs
    are the input, followed by every Variable in the weights, then every
    Variable in the hidden state. Its forward returns the ``.data`` of the
    Variables in ``output``. Its ``symbolic`` is not implemented yet (raises
    RuntimeError) and its backward raises too. ``args``/``kwargs`` are
    currently unused. Returns the proxy output regrouped like ``output``.
    """
    rnn_input, weights, hidden = fargs
    precomputed = tuple(var.data for var in _iter_variables(output))
    weight_vars = tuple(_iter_variables(weights))
    hidden_vars = tuple(_iter_variables(hidden))

    class RNNSymbolic(Function):
        @staticmethod
        def symbolic(g, *fargs):
            # NOTE: fargs holds the Variable inputs (input + weights + hidden);
            # the RNN parameters would come from the outer args/kwargs.
            raise RuntimeError("hack_onnx_rnn NYI")

        @staticmethod
        def forward(ctx, *fargs):
            # Outputs were already computed by the real RNN.
            return precomputed

        @staticmethod
        def backward(ctx, *gargs, **gkwargs):
            raise RuntimeError("FIXME: Traced RNNs don't support backward")

    proxy_inputs = (rnn_input,) + weight_vars + hidden_vars
    return _unflatten(RNNSymbolic.apply(*proxy_inputs), output)
Example #10
0
def hack_onnx_rnn(fargs, output, args, kwargs):
    """Wrap already-computed RNN outputs in a placeholder autograd Function.

    The proxy is applied to ``(input,) + weight Variables + hidden
    Variables`` (all unpacked from ``fargs``). Its forward hands back the
    ``.data`` tensors of ``output``. Its ``symbolic`` raises "hack_onnx_rnn
    NYI" and its backward raises as well. ``args`` and ``kwargs`` are not
    used yet. The flat proxy output is regrouped to match ``output``.
    """
    x, weights, hx_state = fargs
    out_tensors = tuple([item.data for item in _iter_variables(output)])
    flat_w = tuple(_iter_variables(weights))
    flat_h = tuple(_iter_variables(hx_state))

    class RNNSymbolic(Function):
        @staticmethod
        def symbolic(g, *fargs):
            # NOTE: fargs = Variable inputs (input + weight + hidden);
            # RNN parameters live in the enclosing args/kwargs.
            raise RuntimeError("hack_onnx_rnn NYI")

        @staticmethod
        def forward(ctx, *fargs):
            return out_tensors

        @staticmethod
        def backward(ctx, *gargs, **gkwargs):
            raise RuntimeError("FIXME: Traced RNNs don't support backward")

    flat_inputs = (x,) + flat_w + flat_h
    flat_result = RNNSymbolic.apply(*flat_inputs)
    return _unflatten(flat_result, output)
Example #11
0
 def symbolic(g, *flat_args):
     """Build ``symbolic_fn``'s graph in ``g`` and return its flat JIT values.

     ``flat_args`` are regrouped into the structure of the enclosing
     ``args`` before ``symbolic_fn`` is called with the enclosing ``kwargs``.
     """
     graph_out = symbolic_fn(g, *function._unflatten(flat_args, args),
                             **kwargs)
     return tuple(function._iter_jit_values(graph_out))
Example #12
0
 def symbolic(g, *flat_args):
     """Emit ``symbolic_fn`` into graph ``g``; return its outputs as a tuple.

     The flat graph inputs are regrouped like the enclosing ``args``, and
     the result of ``symbolic_fn`` is flattened back into JIT values.
     """
     nested = function._unflatten(flat_args, args)
     result = symbolic_fn(g, *nested, **kwargs)
     jit_values = function._iter_jit_values(result)
     return tuple(jit_values)