def _get_input_argnames(fn: Callable[..., Any], exclude: List[str] = None) -> List[str]: """ Function to get input argument names of function. Args: fn (Callable[..., Any]): Function to get argument names from exclude (List[str]): List of string of names to exclude Returns: (List[str]): List of input argument names """ argspec = inspect.getfullargspec(fn) assert (argspec.varargs is None and argspec.varkw is None), "not supported by PyTorch" return get_fn_argsnames(fn, exclude=exclude)
def test_get_fn_argsnames():
    """Check ``utils.get_fn_argsnames`` against several ``forward`` signatures.

    The cases cover positional, defaulted and keyword-only parameters; ``self``
    is excluded in every call.
    """

    class Net1(nn.Module):
        def forward(self, x):
            return x

    class Net2(nn.Module):
        def forward(self, x, y):
            return x

    class Net3(nn.Module):
        def forward(self, x, y=None):
            return x

    class Net4(nn.Module):
        def forward(self, x, *, y=None):
            return x

    class Net5(nn.Module):
        def forward(self, *, x):
            return x

    class Net6(nn.Module):
        def forward(self, *, x, y):
            return x

    class Net7(nn.Module):
        def forward(self, *, x, y=None):
            return x

    # Each network class is paired with the argument names expected from it.
    cases = [
        (Net1, ["x"]),
        (Net2, ["x", "y"]),
        (Net3, ["x", "y"]),
        (Net4, ["x", "y"]),
        (Net5, ["x"]),
        (Net6, ["x", "y"]),
        (Net7, ["x", "y"]),
    ]

    params_true = [expected for _, expected in cases]
    params_predicted = [
        utils.get_fn_argsnames(net.forward, exclude=["self"])
        for net, _ in cases
    ]

    assert params_predicted == params_true
def __call__(self, *args, **kwargs):
    """Trace the wrapped method with ``torch.jit.trace`` and run ``forward``.

    The keyword-only path looks up ``self.method_name`` on ``self.model``,
    collects the values from ``kwargs`` in the order given by
    ``get_fn_argsnames`` (with ``self`` excluded) and traces with that tuple.
    If anything in that path raises, the call falls back to handing
    ``*args`` and ``**kwargs`` straight to ``torch.jit.trace``.

    The trace is stored in ``self.tracing_result``.

    Returns:
        The result of ``self.model.forward(*args, **kwargs)``.
    """
    method_model = _ForwardOverrideModel(self.model, self.method_name)

    try:
        assert not args, "only KV support implemented"

        target = getattr(self.model, self.method_name)
        spec = inspect.getfullargspec(target)
        assert (
            spec.varargs is None and spec.varkw is None
        ), "not supported by PyTorch tracing"

        argnames = get_fn_argsnames(target, exclude=["self"])
        # A missing key raises KeyError here, which triggers the fallback.
        traced_input = tuple([kwargs[key] for key in argnames])

        self.tracing_result = torch.jit.trace(method_model, traced_input)
    except Exception:
        # for backward compatibility
        self.tracing_result = torch.jit.trace(method_model, *args, **kwargs)

    return self.model.forward(*args, **kwargs)