    def test_partial_flattening(self):
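        # flatten only a subset of the parameters; the remaining sub-modules
        # should keep their original, unflattened parameters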
        module = self._get_transformer()
        num_params = sum(p.numel() for p in module.parameters())

        params_to_flatten = list(module.encoder.layers[1].parameters()) + list(
            module.decoder.layers[0].parameters())
        num_params_to_flatten = sum(p.numel() for p in params_to_flatten)

        module = FlattenParamsWrapper(module, param_list=params_to_flatten)
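        # the flat parameter holds only the selected subset, but the total
        # parameter count of the wrapped module is unchanged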
        assert module.flat_param.numel() == num_params_to_flatten
        assert sum(p.numel() for p in module.parameters()) == num_params

        # flattened parameters are removed
        assert len(list(module.encoder.layers[1].parameters())) == 0
        assert len(list(module.decoder.layers[0].parameters())) == 0

        # non-flattened parameters remain
        assert len(list(module.encoder.layers[0].parameters())) > 0
        assert len(list(module.decoder.layers[1].parameters())) > 0

        # test that changing the module dtype works properly
        orig_dtype = params_to_flatten[0].dtype
        new_dtype = torch.float32 if orig_dtype == torch.float16 else torch.float16
        assert module.flat_param.dtype == orig_dtype
        assert all(p.dtype == orig_dtype
                   for p in module.encoder.layers[0].parameters())
        module = module.to(dtype=new_dtype)
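        # the dtype change should propagate to the flat parameter and to the
        # non-flattened parameters alike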
        assert module.flat_param.dtype == new_dtype
        assert all(p.dtype == new_dtype
                   for p in module.encoder.layers[0].parameters())

    def _test_num_params(self, module):
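        # shared helper: wrapping must preserve the total parameter count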
        ref_num_params = sum(p.numel() for p in module.parameters())

        flat_module = FlattenParamsWrapper(module)
        flat_num_params = sum(p.numel() for p in flat_module.parameters())

        # the wrapper reports the same total, all of it held by flat_param
        assert ref_num_params == flat_num_params
        assert flat_num_params == flat_module.flat_param.numel()

    def test_empty_module(self):
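        # a module without parameters should pass through the wrapper unchanged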
        module = self._get_empty_module()
        in_data = torch.rand(1)
        ref_out = module(in_data)
        module = FlattenParamsWrapper(module)
        assert len(list(module.parameters())) == 0
        assert len(module.state_dict()) == 0
        fpw_out = module(in_data)
        torch.testing.assert_allclose(ref_out, fpw_out)

    def test_two_flattening_group(self):
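        # flatten two disjoint parameter groups into separate flat parameters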
        module = self._get_transformer()
        num_params = sum(p.numel() for p in module.parameters())

        params_to_flatten1 = list(
            module.encoder.layers[1].parameters()) + list(
                module.decoder.layers[0].parameters())
        params_to_flatten2 = list(
            module.encoder.layers[0].parameters()) + list(
                module.decoder.layers[1].parameters())
        num_params_to_flatten1 = sum(p.numel() for p in params_to_flatten1)
        num_params_to_flatten2 = sum(p.numel() for p in params_to_flatten2)

        module = FlattenParamsWrapper(
            module, param_list=[params_to_flatten1, params_to_flatten2])
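        # each group becomes its own flat parameter of the expected size, and
        # the overall parameter count is preserved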
        assert module.flat_params[0].numel() == num_params_to_flatten1
        assert module.flat_params[1].numel() == num_params_to_flatten2
        assert sum(p.numel() for p in module.parameters()) == num_params