Exemplo n.º 1
0
def _get_input_argnames(fn: Callable[..., Any],
                        exclude: List[str] = None) -> List[str]:
    """
    Get the input argument names of a function.

    Functions that declare ``*args`` or ``**kwargs`` are rejected before
    the name lookup is delegated to ``get_fn_argsnames``.

    Args:
        fn (Callable[..., Any]): Function to get argument names from
        exclude (List[str]): List of names to exclude; forwarded unchanged
            (including the default ``None``) to ``get_fn_argsnames``

    Returns:
        (List[str]): List of input argument names

    Raises:
        AssertionError: If ``fn`` accepts variadic positional or variadic
            keyword arguments
    """
    argspec = inspect.getfullargspec(fn)
    # Raised explicitly instead of via ``assert`` so the check is not
    # stripped when Python runs with -O; the exception type and message
    # are the same as the assert statement would produce.
    if argspec.varargs is not None or argspec.varkw is not None:
        raise AssertionError("not supported by PyTorch")

    return get_fn_argsnames(fn, exclude=exclude)
Exemplo n.º 2
0
def test_get_fn_argsnames():
    """
    Check ``utils.get_fn_argsnames`` against several ``forward`` signatures.

    Each module below declares ``forward`` with a different mix of
    positional, defaulted and keyword-only parameters; with ``self``
    excluded, the reported names must match the expected list.
    """
    class Positional(nn.Module):
        def forward(self, x):
            return x

    class TwoPositional(nn.Module):
        def forward(self, x, y):
            return x

    class PositionalWithDefault(nn.Module):
        def forward(self, x, y=None):
            return x

    class PositionalPlusKwOnlyDefault(nn.Module):
        def forward(self, x, *, y=None):
            return x

    class KwOnly(nn.Module):
        def forward(self, *, x):
            return x

    class TwoKwOnly(nn.Module):
        def forward(self, *, x, y):
            return x

    class KwOnlyWithDefault(nn.Module):
        def forward(self, *, x, y=None):
            return x

    # Each module class is paired with the argument names expected for it.
    cases = [
        (Positional, ["x"]),
        (TwoPositional, ["x", "y"]),
        (PositionalWithDefault, ["x", "y"]),
        (PositionalPlusKwOnlyDefault, ["x", "y"]),
        (KwOnly, ["x"]),
        (TwoKwOnly, ["x", "y"]),
        (KwOnlyWithDefault, ["x", "y"]),
    ]

    expected = [names for _, names in cases]
    actual = [
        utils.get_fn_argsnames(net.forward, exclude=["self"])
        for net, _ in cases
    ]
    assert actual == expected
Exemplo n.º 3
0
    def __call__(self, *args, **kwargs):
        """
        Trace the wrapped method and return the model's forward output.

        The trace is stored in ``self.tracing_result``. A keyword-only
        call is attempted first: the keyword values are ordered by the
        target method's argument names and passed to ``torch.jit.trace``
        as a tuple. If anything in that attempt raises, tracing is
        retried with the raw ``args`` and ``kwargs``.

        Args:
            *args: Positional inputs; any positional input makes the
                keyword-only attempt fail and triggers the fallback
            **kwargs: Keyword inputs, looked up by the target method's
                argument names

        Returns:
            The result of ``self.model.forward(*args, **kwargs)``
        """
        traced_module = _ForwardOverrideModel(self.model, self.method_name)

        try:
            assert not args, "only KV support implemented"

            target = getattr(self.model, self.method_name)
            spec = inspect.getfullargspec(target)
            assert not (
                spec.varargs is not None or spec.varkw is not None
            ), "not supported by PyTorch tracing"

            # Order the keyword inputs the way the target method declares
            # its parameters (``self`` excluded).
            ordered_names = get_fn_argsnames(target, exclude=["self"])
            ordered_inputs = tuple(kwargs[key] for key in ordered_names)

            self.tracing_result = torch.jit.trace(
                traced_module, ordered_inputs
            )
        except Exception:
            # for backward compatibility
            self.tracing_result = torch.jit.trace(
                traced_module, *args, **kwargs
            )

        # NOTE(review): the output always comes from ``forward``, not from
        # ``self.method_name`` -- confirm this is intended.
        return self.model.forward(*args, **kwargs)