def test_partial_flattening(self):
        """Flatten a subset of the parameters and verify the wrapper's bookkeeping.

        Parameters of ``encoder.layers[1]`` and ``decoder.layers[0]`` are passed
        as ``param_list``. The test checks the size of ``flat_param``, the total
        element count, which layers still expose parameters, and that
        ``.to(dtype=...)`` affects both flattened and non-flattened parameters.
        """
        module = self._get_transformer()
        total_numel = sum(p.numel() for p in module.parameters())

        params_to_flatten = [
            *module.encoder.layers[1].parameters(),
            *module.decoder.layers[0].parameters(),
        ]
        flattened_numel = sum(p.numel() for p in params_to_flatten)

        module = FlattenParamsWrapper(module, param_list=params_to_flatten)
        assert module.flat_param.numel() == flattened_numel
        assert sum(p.numel() for p in module.parameters()) == total_numel

        # flattened parameters are removed
        for flattened_layer in (module.encoder.layers[1], module.decoder.layers[0]):
            assert not list(flattened_layer.parameters())

        # non-flattened parameters remain
        for untouched_layer in (module.encoder.layers[0], module.decoder.layers[1]):
            assert list(untouched_layer.parameters())

        # test that changing the module dtype works properly
        orig_dtype = params_to_flatten[0].dtype
        if orig_dtype == torch.float16:
            new_dtype = torch.float32
        else:
            new_dtype = torch.float16

        assert module.flat_param.dtype == orig_dtype
        for param in module.encoder.layers[0].parameters():
            assert param.dtype == orig_dtype

        module = module.to(dtype=new_dtype)
        assert module.flat_param.dtype == new_dtype
        for param in module.encoder.layers[0].parameters():
            assert param.dtype == new_dtype
    def _test_num_params(self, module):
        """Assert that wrapping ``module`` keeps the total parameter element count.

        Also asserts that the wrapped module's element count equals the size of
        its ``flat_param``.
        """

        def count_elements(mod):
            # Total number of scalar elements across all parameters of `mod`.
            return sum(p.numel() for p in mod.parameters())

        expected_numel = count_elements(module)

        wrapped = FlattenParamsWrapper(module)
        wrapped_numel = count_elements(wrapped)

        assert expected_numel == wrapped_numel
        assert wrapped_numel == wrapped.flat_param.numel()
    def test_state_dict_equality(self):
        """The wrapped module's ``state_dict`` must equal the unwrapped one."""
        base_module = self._get_shared_params_transformer()
        # Reference is read before the module is handed to the wrapper.
        expected_state = base_module.state_dict()

        wrapped_state = FlattenParamsWrapper(base_module).state_dict()

        assert objects_are_equal(expected_state, wrapped_state)
 def test_empty_module(self):
     """Wrap the module from ``_get_empty_module`` and check it stays empty.

     The wrapped module must expose no parameters, have an empty state dict,
     and produce the same output as before wrapping for the same input.
     """
     bare_module = self._get_empty_module()
     sample = torch.rand(1)
     expected_out = bare_module(sample)

     wrapped = FlattenParamsWrapper(bare_module)
     assert not list(wrapped.parameters())
     assert not wrapped.state_dict()

     torch.testing.assert_allclose(expected_out, wrapped(sample))
    def test_load_state_dict(self):
        """Load an unwrapped module's state dict into a wrapped module.

        The wrapped module is created with ``seed=1234``; after loading the
        reference state dict its output must equal the reference output.
        """
        source = self._get_shared_params_transformer()
        ref_state_dict = source.state_dict()
        ref_output = self._get_output(source)

        target = FlattenParamsWrapper(self._get_shared_params_transformer(seed=1234))
        target.load_state_dict(ref_state_dict)

        assert objects_are_equal(ref_output, self._get_output(target))
 def test_flatten_nothing(self):
     """Wrap with ``param_list=[[]]`` and check state dict and output are unchanged."""
     module = self._get_transformer()
     expected_out = self._get_output(module)

     # Replace each entry with a clone so the reference holds copies of the
     # values taken before wrapping.
     expected_state = module.state_dict()
     for key in expected_state:
         expected_state[key] = expected_state[key].clone()

     wrapped = FlattenParamsWrapper(module, param_list=[[]])
     wrapped_state = wrapped.state_dict()

     assert expected_state.keys() == wrapped_state.keys()
     for key, expected_value in expected_state.items():
         torch.testing.assert_allclose(expected_value, wrapped_state[key])

     torch.testing.assert_allclose(expected_out, self._get_output(wrapped))
Exemplo n.º 7
0
    def test_state_dict_equality(self):
        """Check that the wrapped state dict matches the original (unwrapped) one.

        Runs for every module produced by ``_get_module_init_fns``; compares
        the key sets first, then the full contents.
        """
        # All modules are instantiated up front, before any of them is wrapped.
        candidates = [make_module() for make_module in self._get_module_init_fns()]
        for candidate in candidates:
            ref_state_dict = candidate.state_dict()
            flat_state_dict = FlattenParamsWrapper(candidate).state_dict()

            same_keys = ref_state_dict.keys() == flat_state_dict.keys()
            assert same_keys, f"{ref_state_dict.keys()} != {flat_state_dict.keys()}"

            same_contents = objects_are_equal(ref_state_dict, flat_state_dict)
            assert same_contents, f"{ref_state_dict} != {flat_state_dict}"
Exemplo n.º 8
0
    def test_unflatten_params(self):
        """Exercise ``unflatten_params`` with and without a replacement flat param.

        For every module from ``_get_module_init_fns``: snapshot the state dict
        inside the context manager, re-enter with a flat param filled with 42.0
        and check only non-buffer entries change, confirm the original values
        are back after exit, then load the modified state dict and check that
        ``flat_param`` matches the replacement.
        """
        for module_init_fn in self._get_module_init_fns():
            module = FlattenParamsWrapper(module_init_fn())
            # Buffer names with the "_fpw_module." prefix stripped; matched
            # against state dict keys below.
            buffer_names = {
                name.replace("_fpw_module.", "") for name, _ in module.named_buffers()
            }

            def snapshot_state_dict():
                # Copy of the current state dict with every value cloned.
                cloned = OrderedDict()
                for key, value in module.state_dict().items():
                    cloned[key] = value.clone()
                return cloned

            flat_before = module.flat_param.clone()
            with module.unflatten_params():
                ref_state_dict = snapshot_state_dict()
            assert not torch.all(flat_before == 0)

            # confirm that unflatten_params reflects values from new_flat_param
            new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
            with module.unflatten_params(flat_param=new_flat_param):
                new_state_dict = snapshot_state_dict()
                assert new_state_dict.keys() == ref_state_dict.keys()
                for key, value in new_state_dict.items():
                    if key not in buffer_names:
                        # params reflect new_flat_param value
                        assert torch.all(value == 42.0)
                        continue
                    # buffers are not changed
                    torch.testing.assert_allclose(value, ref_state_dict[key])

            # after context manager exits, we go back to previous (reference) state
            torch.testing.assert_allclose(module.flat_param, flat_before)
            with module.unflatten_params():
                assert objects_are_equal(ref_state_dict, snapshot_state_dict())

            # if we load the new_state_dict, then the flat param should match new_flat_param
            module.load_state_dict(new_state_dict)
            torch.testing.assert_allclose(module.flat_param, new_flat_param)
Exemplo n.º 9
0
    def _get_nested_flat_module(self, seed=0):
        """Build a ``Sequential`` made of nested ``FlattenParamsWrapper`` modules.

        Args:
            seed: Value passed to ``torch.manual_seed`` before the layers are
                constructed, so the initial weights depend on it.

        Returns:
            The assembled module, with a ``get_input(device, dtype)`` function
            attached that returns a one-element tuple holding an ``(8, 4)``
            random tensor generated under a fixed seed.
        """
        # `seed` used to be accepted but ignored; seed the RNG here so that
        # different seeds actually yield differently initialised modules.
        torch.manual_seed(seed)
        module = torch.nn.Sequential(
            FlattenParamsWrapper(
                torch.nn.Sequential(torch.nn.Linear(4, 8), FlattenParamsWrapper(torch.nn.Linear(8, 8)))
            ),
            FlattenParamsWrapper(torch.nn.Sequential(FlattenParamsWrapper(torch.nn.Linear(8, 16)))),
            FlattenParamsWrapper(torch.nn.Linear(16, 4)),
        )

        def get_input(device, dtype):
            torch.manual_seed(1)  # keep everything deterministic
            return (torch.rand(8, 4).to(device=device, dtype=dtype),)

        module.get_input = get_input
        return module
    def test_two_flattening_group(self):
        """Flatten two disjoint parameter groups and check each flat param's size.

        Group 1 is ``encoder.layers[1]`` + ``decoder.layers[0]``; group 2 is
        ``encoder.layers[0]`` + ``decoder.layers[1]``. The total element count
        of the wrapped module must equal that of the original.
        """
        module = self._get_transformer()
        total_numel = sum(p.numel() for p in module.parameters())

        enc_layers = module.encoder.layers
        dec_layers = module.decoder.layers
        first_group = [*enc_layers[1].parameters(), *dec_layers[0].parameters()]
        second_group = [*enc_layers[0].parameters(), *dec_layers[1].parameters()]

        first_numel = sum(p.numel() for p in first_group)
        second_numel = sum(p.numel() for p in second_group)

        module = FlattenParamsWrapper(module, param_list=[first_group, second_group])

        assert module.flat_params[0].numel() == first_numel
        assert module.flat_params[1].numel() == second_numel
        assert sum(p.numel() for p in module.parameters()) == total_numel
Exemplo n.º 11
0
    def test_shared_params_pnorm_after_step(self):
        """Compare ``_get_pnorm_after_step`` results for unwrapped and wrapped modules."""
        # incorrect parameter sharing is likely to cause problems after an
        # optimization step
        expected_pnorm = self._get_pnorm_after_step(self._get_shared_params_transformer())

        # Build a fresh module for the wrapped run rather than reusing the first.
        wrapped = FlattenParamsWrapper(self._get_shared_params_transformer())
        wrapped_pnorm = self._get_pnorm_after_step(wrapped)

        torch.testing.assert_allclose(expected_pnorm, wrapped_pnorm)
    def test_flat_state_dict(self):
        """Test that flat state dict can be reloaded and produces the same results.

        Runs for every init function from ``_get_module_init_fns``; the module
        receiving the state dict is created with ``seed=1234``.
        """
        for make_module in self._get_module_init_fns():
            source = FlattenParamsWrapper(make_module())
            expected_output = self._get_output(source)
            saved_state = source.flat_state_dict()

            target = FlattenParamsWrapper(make_module(seed=1234))
            target.load_state_dict(saved_state)

            assert objects_are_equal(expected_output, self._get_output(target))
    def test_flat_state_dict(self):
        """Reload a ``flat_state_dict`` into a second wrapped module and compare outputs.

        The second module is created with ``seed=1234`` before the saved flat
        state dict is loaded into it.
        """
        source = FlattenParamsWrapper(self._get_shared_params_transformer())
        expected_output = self._get_output(source)
        saved_state = source.flat_state_dict()

        target = FlattenParamsWrapper(self._get_shared_params_transformer(seed=1234))
        target.load_state_dict(saved_state)

        assert objects_are_equal(expected_output, self._get_output(target))
    def test_load_state_dict(self):
        """Test that original (unwrapped) state_dict can be loaded in wrapped module.

        The load is performed twice per init function: once directly and once
        inside ``unflatten_params``; the output is compared after each load.
        """
        for make_module in self._get_module_init_fns():
            reference = make_module()
            ref_state_dict = reference.state_dict()
            ref_output = self._get_output(reference)

            wrapped = FlattenParamsWrapper(make_module(seed=1234))

            # This should work without the unflatten_params context manager
            wrapped.load_state_dict(ref_state_dict)
            assert objects_are_equal(ref_output, self._get_output(wrapped))

            # And it should work with the context manager too
            with wrapped.unflatten_params():
                wrapped.load_state_dict(ref_state_dict)
            assert objects_are_equal(ref_output, self._get_output(wrapped))
    def _test_output(self, module):
        """Assert that ``module`` yields the same ``_get_output`` result once wrapped."""
        # Reference output is taken before the module is wrapped.
        expected_output = self._get_output(module)

        wrapped_output = self._get_output(FlattenParamsWrapper(module))
        assert objects_are_equal(expected_output, wrapped_output)