예제 #1
0
    def normalized_arguments(
            self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
            kwarg_types : Optional[Dict[str, Any]] = None,
            normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
        """
        Returns normalized arguments to Python targets. This means that
        `args/kwargs` will be matched up to the module/functional's
        signature and return exclusively kwargs in positional order
        if `normalize_to_only_use_kwargs` is true.
        Also populates default values. Does not support positional-only
        parameters or varargs parameters.

        Supports module calls.

        May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

        Args:
            root (torch.nn.Module): Module upon which to resolve module targets.
            arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args.
                Only consulted for ``call_function`` nodes.
            kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs.
                Only consulted for ``call_function`` nodes.
            normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
                Forwarded to both the function and the module normalizer.

        Returns:
            Returns NamedTuple ArgsKwargsPair, or `None` if not successful.
            Nodes whose ``op`` is neither ``call_function`` nor ``call_module``
            always yield `None`.
        """
        if self.op == 'call_function':
            assert callable(self.target)
            # Forward the flag so callers asking for kwargs-only output actually get it.
            return normalize_function(self.target, self.args, self.kwargs,  # type: ignore[arg-type]
                                      arg_types, kwarg_types,
                                      normalize_to_only_use_kwargs=normalize_to_only_use_kwargs)
        elif self.op == 'call_module':
            assert isinstance(self.target, str)
            # Same flag for module targets; arg/kwarg types are not used on this path.
            return normalize_module(root, self.target, self.args, self.kwargs,  # type: ignore[arg-type]
                                    normalize_to_only_use_kwargs=normalize_to_only_use_kwargs)

        return None
예제 #2
0
 def call_module(self, target: Target, args: Tuple[Argument, ...],
                 kwargs: Dict[str, Any]):
     """
     Run a ``call_module`` node, preferring the normalizer's keyword form.

     ``normalize_module`` is given ``self.module``, ``target``, ``args`` and
     ``kwargs``. If it returns a truthy value, the parent implementation is
     called with no positional arguments and that value as the kwargs
     (presumably a kwargs dict in the normalizer version this targets --
     verify). Otherwise the original ``args``/``kwargs`` are forwarded as-is.
     """
     assert isinstance(target, str)
     normalized = normalize_module(self.module, target, args, kwargs)  # type: ignore
     if not normalized:
         # Nothing usable came back: keep the call exactly as it was written.
         return super().call_module(target, args, kwargs)
     return super().call_module(target, (), normalized)
예제 #3
0
 def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
     """
     Run a ``call_module`` node using arguments normalized to the submodule.

     ``normalize_module`` receives ``self.module``, ``target``, ``args``,
     ``kwargs`` and ``self.normalize_to_only_use_kwargs``. A truthy result is
     unpacked into the positional and keyword arguments passed to the parent
     implementation. A falsy result (e.g. ``None``) means the original
     ``args``/``kwargs`` are forwarded unchanged.
     """
     assert isinstance(target, str)
     normalized = normalize_module(self.module, target, args, kwargs,  # type: ignore[arg-type]
                                   self.normalize_to_only_use_kwargs)
     if not normalized:
         return super().call_module(target, args, kwargs)
     norm_args, norm_kwargs = normalized
     return super().call_module(target, norm_args, norm_kwargs)