def test_normalize_args(self):
    """Check that ``NormalizeArgs`` moves call arguments into kwargs.

    ResNet-18 is traced with a custom tracer that keeps only
    ``torch.nn.BatchNorm2d`` as a leaf module; every other submodule is
    traced through into its ``forward``. The graph is then rewritten with
    ``NormalizeArgs`` and the test asserts that:

    * the rewritten module produces the same output as the traced one, and
    * every ``torch.nn.functional`` call and every call to a standard
      ``torch.nn`` module class has no positional args left.
    """
    m = resnet18()

    class FunctionalTracer(torch.fx.Tracer):
        def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
            # `leaves` contains the set of standard `nn.Modules` that are not
            # currently symbolically traceable. Ideally this set would be empty
            leaves = {torch.nn.BatchNorm2d}
            return type(m) in leaves

    traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

    # Named `inp` (not `input`) so the builtin is not shadowed.
    inp = torch.randn(5, 3, 224, 224)
    ref_outs = traced(inp)
    traced = NormalizeArgs(traced).transform()
    test_outs = traced(inp)
    self.assertEqual(test_outs, ref_outs)

    # Ensure all args/kwargs are normalized into kwargs
    modules = dict(traced.named_modules())
    for node in traced.graph.nodes:
        if node.op == 'call_function' and node.target.__module__ == 'torch.nn.functional':
            self.assertEqual(len(node.args), 0)
        if node.op == 'call_module':
            submod_class = modules[node.target].__class__
            # Only classes exported under `torch.nn` are checked. A default of
            # None makes non-exported/custom classes fall through instead of
            # raising AttributeError.
            nn_class = getattr(torch.nn, submod_class.__name__, None)
            if submod_class is nn_class:
                self.assertEqual(len(node.args), 0)
def test_normalize_modules_exhaustive(self):
    """
    Exhaustively test `NormalizeArgs` on all standard torch.nn Module classes

    For every entry in ``module_tests + new_module_tests`` whose constructed
    module is a public ``torch.nn`` class, wrap it in a generated
    ``torch.nn.Module``, symbolically trace the wrapper, and apply
    ``NormalizeArgs``. Outputs are compared against the unwrapped module
    (except for stochastic modules). Every call to a standard ``torch.nn``
    module in the rewritten graph must have no positional args.
    """
    # These Modules have an RNG in their forward, so testing
    # correctness by comparing outputs is not correct. Skip that
    # check for these
    stochastic_modules = {'FractionalMaxPool2d', 'FractionalMaxPool3d', 'RReLU'}
    # Loop-invariant: the public names of `torch.nn`, as a set for O(1) lookup.
    nn_names = set(dir(torch.nn))

    for test_params in module_tests + new_module_tests:
        if 'constructor' not in test_params:
            constructor = getattr(torch.nn, test_params['module_name'])
        else:
            constructor = test_params['constructor']

        args = test_params.get('constructor_args', ())
        mod = constructor(*args)
        mod_name = mod.__class__.__name__

        # Skip modules that are not standard `torch.nn`
        # instances, including functionals. (functionals
        # are tested in test_normalize_args)
        if mod_name not in nn_names:
            continue

        if 'input_fn' not in test_params:
            inputs = torch.randn(test_params['input_size'])
        else:
            inputs = test_params['input_fn']()

        if not isinstance(inputs, (tuple, list)):
            inputs = (inputs, )

        # One positional parameter name per input: "v0, v1, ..."
        params = ', '.join(f'v{i}' for i in range(len(inputs)))

        # Generate a class to wrap this standard `nn.Module` instance
        test_classname = f'Test{mod_name}'

        test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
"""

        gbls = {'torch': torch}
        exec(test_mod_code, gbls)

        test_instance = gbls[test_classname](mod)
        traced = symbolic_trace(test_instance)

        # Now actually test arg normalization!
        traced = NormalizeArgs(traced).transform()

        if mod_name not in stochastic_modules:
            self.assertEqual(traced(*inputs), mod(*inputs))

        # Ensure all args/kwargs are normalized into kwargs
        modules = dict(traced.named_modules())
        for node in traced.graph.nodes:
            if node.op == 'call_module':
                submod_class = modules[node.target].__class__
                # A default of None skips classes not exported from `torch.nn`
                # instead of raising AttributeError.
                nn_class = getattr(torch.nn, submod_class.__name__, None)
                if submod_class is nn_class:
                    self.assertEqual(len(node.args), 0)