    def test_normalize_args(self):
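        """
        Test that `NormalizeArgs` rewrites `torch.nn.functional` calls and
        standard `torch.nn` module calls to take their arguments as kwargs
        on a traced ResNet-18, without changing the model's output
        """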
        m = resnet18()

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module,
                               module_qualified_name: str) -> bool:
                # `leaves` is the set of standard `nn.Module` classes that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

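        # Trace the model with the custom tracer so that `BatchNorm2d` instances
        # are kept as opaque `call_module` leaves rather than traced through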
        traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        input = torch.randn(5, 3, 224, 224)
        ref_outs = traced(input)

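        # Normalize the traced graph so that calls receive their arguments as kwargs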
        traced = NormalizeArgs(traced).transform()

        test_outs = traced(input)
        self.assertEqual(test_outs, ref_outs)

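        # Every `torch.nn.functional` call and standard `torch.nn` module call
        # should now have no positional args left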
        modules = dict(traced.named_modules())
        for node in traced.graph.nodes:
            if node.op == 'call_function' and node.target.__module__ == 'torch.nn.functional':
                self.assertEqual(len(node.args), 0)
            if node.op == 'call_module':
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)

    def test_normalize_modules_exhaustive(self):
        """
        Exhaustively test `NormalizeArgs` on all standard
        torch.nn Module classes
        """
        for test_params in module_tests + new_module_tests:
            if 'constructor' not in test_params:
                constructor = getattr(torch.nn, test_params['module_name'])
            else:
                constructor = test_params['constructor']

            if 'constructor_args' not in test_params:
                args = ()
            else:
                args = test_params['constructor_args']

            mod = constructor(*args)
            # Skip modules whose class is not a standard `torch.nn` module.
            # Functionals are covered separately in `test_normalize_args`
            if mod.__class__.__name__ not in dir(torch.nn):
                continue

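            # Build the test inputs, either from an explicit `input_fn` or by
            # sampling random data of the declared `input_size`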
            if 'input_fn' not in test_params:
                inputs = torch.randn(test_params['input_size'])
            else:
                inputs = test_params['input_fn']()

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs,)

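            # Build a positional parameter list `v0, v1, ...`, one name per input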
            params = ', '.join(f'v{i}' for i in range(len(inputs)))

            # Generate a class to wrap this standard `nn.Module` instance
            test_classname = f'Test{mod.__class__.__name__}'
            test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
            """

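            # Execute the generated class definition in a fresh namespace that
            # provides `torch`, then instantiate and symbolically trace the wrapper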
            gbls = {'torch': torch}
            exec(test_mod_code, gbls)

            test_instance = gbls[test_classname](mod)
            traced = symbolic_trace(test_instance)

            # Now actually test arg normalization!
            traced = NormalizeArgs(traced).transform()

            # These modules use an RNG in their forward pass, so their outputs are
            # not deterministic and cannot be compared directly against the eager
            # module. Skip the output-equality check for them
            stochastic_modules = {
                'FractionalMaxPool2d', 'FractionalMaxPool3d', 'RReLU'
            }

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            # Ensure all positional args were normalized into kwargs
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == 'call_module':
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)