Example #1
    # `fn`, `symbolic_fn`, and `first_arg_only` are free variables bound
    # by the enclosing symbolic_override decorator.
    def wrapper(*args, **kwargs):
        output = fn(*args, **kwargs)
        # fast pass
        if first_arg_only and not torch._C._jit_is_tracing(args[0]):
            return output

        flat_args = tuple(function._iter_variables(args))
        if not any(map(torch._C._jit_is_tracing, flat_args)):
            return output
        flat_output_tensors = tuple(
            v.data for v in function._iter_variables(output))
        # TODO: kwargs aren't traced

        class ExportProxy(Function):
            @staticmethod
            def symbolic(g, *flat_args):
                symbolic_args = function._unflatten(flat_args, args)
                symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
                return tuple(function._iter_jit_values(symbolic_output))

            @staticmethod
            def forward(ctx, *unused_args):
                return flat_output_tensors

            @staticmethod
            def backward(ctx, *unused_args, **unused_kwargs):
                raise RuntimeError(
                    "symbolic_override is meant for inference export only")

        flat_proxy_output = ExportProxy.apply(*flat_args)
        return function._unflatten(flat_proxy_output, output)
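This wrapper is the closure produced by one of the old (pre-0.4) torch.onnx symbolic-override decorators: it runs the real fn, and if a trace is active it routes the flattened inputs through an ExportProxy whose symbolic method emits the ONNX nodes. A minimal usage sketch, assuming that era's torch.onnx.symbolic_override decorator and the standard g.op graph-building helper; elu_symbolic, my_elu, and the attribute name are illustrative:

    import torch
    import torch.nn.functional as F

    # Hypothetical symbolic counterpart: receives the JIT graph `g` plus
    # the traced inputs and records a single ONNX Elu node via g.op.
    def elu_symbolic(g, input, alpha=1.0):
        return g.op("Elu", input, alpha_f=alpha)

    # Under tracing, calls to `my_elu` are then recorded as that one node
    # instead of elu's internal operations.
    @torch.onnx.symbolic_override(elu_symbolic)
    def my_elu(input, alpha=1.0):
        return F.elu(input, alpha)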
Example #2
    # `fn`, `symbolic_fn`, and `might_trace` are free variables bound by
    # the enclosing decorator.
    def wrapper(*args, **kwargs):
        import torch
        import torch.jit
        from torch.autograd import Function, function

        # fast pass
        if not might_trace(args):
            return fn(*args, **kwargs)

        flat_args = tuple(function._iter_variables(args))
        if not any(map(torch._C._jit_is_tracing, flat_args)):
            return fn(*args, **kwargs)

        tstate = torch._C._get_tracing_state(flat_args)

        arg_values = [torch._C._get_value_trace(tstate, x) for x in flat_args]

        # This must come after the calls to get_value_trace, lest we
        # lose information due to in-place operations.
        output_vars = fn(*args, **kwargs)

        symbolic_args = function._unflatten(arg_values, args)
        output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)

        for var, val in zip(
                function._iter_variables(output_vars),
                function._iter_jit_values(output_vals)):
            val.inferTypeFrom(var.data)
            torch._C._set_value_trace(tstate, var, val)

        return output_vars
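Unlike Example #1, this variant never builds a proxy Function: it runs fn normally (so autograd history stays intact) and then splices symbolic_fn's nodes into the live trace by remapping each output Variable onto its new Value with _set_value_trace. The snippet assumes a might_trace predicate supplied by the enclosing decorator; a plausible first-arg-based reconstruction (hypothetical, for illustration only) is a cheap pre-check that avoids flattening args on every call:

    def might_trace(args):
        import torch
        from torch.autograd import Variable
        # Fast-path test: only do the full flatten-and-check if the first
        # positional argument is a Variable that is currently being traced.
        first = args[0]
        return isinstance(first, Variable) and torch._C._jit_is_tracing(first)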
Example #3
        def wrapper(*args, **kwargs):
            output = fn(*args, **kwargs)
            flat_args = tuple(function._iter_variables(args))
            if not any(map(torch._C._jit_is_tracing, flat_args)):
                return output
            flat_output_tensors = tuple(
                v.data for v in function._iter_variables(output))
            assert len(list(function._iter_variables_permissive(
                list(kwargs.values())))) == 0, \
                "Passing Variable through kwargs is not supported"

            class ExportProxy(Function):
                @staticmethod
                def symbolic(g, *flat_args):
                    symbolic_args = function._unflatten(flat_args, args)
                    symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
                    return tuple(function._iter_jit_values(symbolic_output))

                @staticmethod
                def forward(ctx, *unused_args):
                    return flat_output_tensors

                @staticmethod
                def backward(ctx, *unused_args, **unused_kwargs):
                    raise RuntimeError(
                        "symbolic_override is meant for inference export only")

            flat_proxy_output = ExportProxy.apply(*flat_args)
            return function._unflatten(flat_proxy_output, output)
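The assert here makes the limitation noted in Example #1's TODO explicit: keyword arguments are forwarded to symbolic_fn as plain Python values and are never flattened into flat_args, so a Variable passed by keyword would silently drop out of the trace. Illustration with a hypothetical wrapped op and Variables x and weight:

    # `wrapped_affine` stands in for any function wrapped this way.
    out = wrapped_affine(x, weight)         # fine: both Variables are traced
    out = wrapped_affine(x, weight=weight)  # AssertionError: Variable in kwargs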
Example #4
    def wrapper(*args, **kwargs):
        output = fn(*args, **kwargs)
        # fast pass
        if first_arg_only and not torch._C._jit_is_tracing(args[0]):
            return output

        flat_args = tuple(function._iter_variables(args))
        if not any(map(torch._C._jit_is_tracing, flat_args)):
            return output
        flat_output_tensors = tuple(
            v.data for v in function._iter_variables(output))
        assert len(list(function._iter_variables_permissive(
            list(kwargs.values())))) == 0, \
            "Passing Variable through kwargs is not supported"

        class ExportProxy(Function):
            @staticmethod
            def symbolic(g, *flat_args):
                symbolic_args = function._unflatten(flat_args, args)
                symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
                return tuple(function._iter_jit_values(symbolic_output))

            @staticmethod
            def forward(ctx, *unused_args):
                return flat_output_tensors

            @staticmethod
            def backward(ctx, *unused_args, **unused_kwargs):
                raise RuntimeError(
                    "symbolic_override is meant for inference export only")

        flat_proxy_output = ExportProxy.apply(*flat_args)
        return function._unflatten(flat_proxy_output, output)
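The backward that unconditionally raises is the point of the name ExportProxy: these overrides exist only so an inference graph can be exported to ONNX, and any attempt to backpropagate through one fails loudly. The intended workflow is ordinary export; a sketch, where model and dummy_input are placeholders:

    import torch
    import torch.onnx

    model.eval()  # export an inference-mode graph; no backward is taken
    torch.onnx.export(model, dummy_input, "model.onnx")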
Example #5
    def wrapper(*args, **kwargs):
        import torch
        import torch.jit
        from torch.autograd import Function, function, Variable

        # fast pass
        if not might_trace(args):
            return fn(*args, **kwargs)

        flat_args = tuple(function._iter_variables_permissive(args))
        flat_args_only_variables = tuple(x for x in flat_args if isinstance(x, Variable))
        if not any(map(torch._C._jit_is_tracing, flat_args_only_variables)):
            return fn(*args, **kwargs)

        tstate = torch._C._get_tracing_state(flat_args_only_variables)

        arg_values = [
            torch._C._get_value_trace(tstate, x) if isinstance(x, Variable) else x
            for x in flat_args
        ]

        # This must come after the calls to get_value_trace, lest we
        # lose information due to in-place operations.
        output_vars = fn(*args, **kwargs)

        symbolic_args = function._unflatten(arg_values, args)
        output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)

        for var, val in zip(
                function._iter_variables(output_vars),
                function._iter_jit_values(output_vals)):
            val.inferTypeFrom(var.data)
            torch._C._set_value_trace(tstate, var, val)

        return output_vars
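This permissive variant differs from Example #2 in that non-Variable leaves (Python scalars and the like) survive flattening and are handed to symbolic_fn unchanged; only the Variables are swapped for their trace Values. The re-nesting relies on function._unflatten, which behaves roughly as follows (a hypothetical re-implementation for intuition, not the real code): it rebuilds the nested shape of proto, consuming the flat sequence one leaf at a time.

    def unflatten(flat, proto):
        # Rebuild proto's list/tuple nesting, drawing leaves from `flat`
        # in order; every non-container leaf of proto is replaced.
        it = iter(flat)
        def rebuild(obj):
            if isinstance(obj, (list, tuple)):
                return type(obj)(rebuild(o) for o in obj)
            return next(it)
        return rebuild(proto)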
Example #6
File: rnn.py  Project: Northrend/pytorch
def hack_onnx_rnn(fargs, output, args, kwargs):
    input, all_weights, hx = fargs
    output_tensors = tuple(v.data for v in _iter_variables(output))
    flat_weights = tuple(_iter_variables(all_weights))
    flat_hx = tuple(_iter_variables(hx))

    class RNNSymbolic(Function):
        @staticmethod
        def symbolic(g, *fargs):
            # NOTE: fargs contains Variable inputs (input + weight + hidden)
            # NOTE: args/kwargs contain RNN parameters
            raise RuntimeError("hack_onnx_rnn NYI")

        @staticmethod
        def forward(ctx, *fargs):
            return output_tensors

        @staticmethod
        def backward(ctx, *gargs, **gkwargs):
            raise RuntimeError("FIXME: Traced RNNs don't support backward")

    flat_output = RNNSymbolic.apply(*((input,) + flat_weights + flat_hx))
    return _unflatten(flat_output, output)
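The argument layout matters here: _iter_variables yields Variable leaves depth-first, so the proxy receives the input, then every layer's weights in order, then the hidden states, and _unflatten restores the (output, hidden) nesting on the way out. For a single-layer LSTM the flattening would look like this (illustrative names):

    # all_weights nests one list per layer, as in torch.nn.RNNBase:
    #   all_weights = [[w_ih, w_hh, b_ih, b_hh]]
    #   hx          = (h_0, c_0)
    flat_weights = tuple(_iter_variables(all_weights))  # (w_ih, w_hh, b_ih, b_hh)
    flat_hx = tuple(_iter_variables(hx))                # (h_0, c_0)
    flat_inputs = (input,) + flat_weights + flat_hx     # 7 Variables total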
Example #7
File: __init__.py  Project: tibuch/pytorch
def _flatten(obj, params=tuple()):
    obj_vars = tuple(itertools.chain(function._iter_variables(obj), params))
    obj_struct = function._nested_map(lambda o: isinstance(o, Variable),
                                      lambda x: HOLE)(obj)
    return obj_vars, obj_struct
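_flatten pairs the Variable leaves with a structure key in which every Variable has been replaced by the HOLE sentinel; the tracer later uses that key to re-nest flat results. For example, with Variables x, y, and z:

    obj_vars, obj_struct = _flatten(((x, y), [z]))
    # obj_vars   == (x, y, z)
    # obj_struct == ((HOLE, HOLE), [HOLE])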
Example #8
File: jit.py  Project: xlovelace/pytorch
def flatten(x):
    """
    Flatten an arbitrarily nested structure of Variables into
    a tuple of Variables.
    """
    return tuple(function._iter_variables(x))
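A quick usage sketch, using the old Variable-era API consistent with the snippet above:

    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(2))
    y = Variable(torch.randn(3))
    flatten((x, [y, (x,)]))  # -> (x, y, x)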