Ejemplo n.º 1
0
    def test_unflatten_params(self):
        """Check ``unflatten_params`` against the wrapper's flat parameter.

        For every module produced by ``self._get_module_init_fns()``, the module
        is wrapped in ``FlattenParamsWrapper`` and the test verifies that:

        * inside ``unflatten_params(flat_param=...)`` every non-buffer
          state_dict entry takes the substituted value (42.0), while buffers
          keep their reference values;
        * after the context exits, ``flat_param`` and the unflattened
          state_dict match the reference again;
        * loading the substituted state_dict writes those values back into
          ``flat_param``.
        """
        for init_fn in self._get_module_init_fns():
            wrapped = FlattenParamsWrapper(init_fn())

            # Buffer names as they appear in state_dict keys, i.e. with the
            # wrapper's "_fpw_module." prefix stripped.
            buffer_names = set()
            for name, _buf in wrapped.named_buffers():
                buffer_names.add(name.replace("_fpw_module.", ""))

            def snapshot():
                # Clone every tensor so the snapshot is independent of later
                # changes to the module's storage.
                snap = OrderedDict()
                for key, tensor in wrapped.state_dict().items():
                    snap[key] = tensor.clone()
                return snap

            original_flat = wrapped.flat_param.clone()
            with wrapped.unflatten_params():
                original_sd = snapshot()
            # Sanity check: the reference parameters are not all zero.
            assert not torch.all(original_flat == 0)

            # A flat param passed to the context manager must be what the
            # unflattened state_dict reports.
            substitute_flat = torch.full_like(wrapped.flat_param, fill_value=42.0)
            with wrapped.unflatten_params(flat_param=substitute_flat):
                substituted_sd = snapshot()
                assert substituted_sd.keys() == original_sd.keys()
                for key, tensor in substituted_sd.items():
                    if key not in buffer_names:
                        # Parameters take their values from substitute_flat.
                        assert torch.all(tensor == 42.0)
                    else:
                        # Buffers are untouched by the substitution.
                        torch.testing.assert_allclose(tensor, original_sd[key])

            # Once the context exits, the original flat param is back in place.
            torch.testing.assert_allclose(wrapped.flat_param, original_flat)
            with wrapped.unflatten_params():
                restored_sd = snapshot()
                assert objects_are_equal(original_sd, restored_sd)

            # Loading the substituted state_dict must propagate its values into
            # the flat param.
            wrapped.load_state_dict(substituted_sd)
            torch.testing.assert_allclose(wrapped.flat_param, substitute_flat)

    def test_load_state_dict(self):
        """Verify a wrapped module accepts a state_dict from an unwrapped one.

        A reference state_dict and output are taken from ``module_init_fn()``.
        A second module built with ``module_init_fn(seed=1234)`` is wrapped in
        ``FlattenParamsWrapper``, loaded with the reference state_dict, and must
        then reproduce the reference output. This is checked both when loading
        directly and when loading inside ``unflatten_params()``.
        """
        for init_fn in self._get_module_init_fns():
            source_module = init_fn()
            expected_sd = source_module.state_dict()
            expected_out = self._get_output(source_module)

            wrapped = FlattenParamsWrapper(init_fn(seed=1234))

            # Loading must work on the wrapper directly, with no
            # unflatten_params() context involved.
            wrapped.load_state_dict(expected_sd)
            assert objects_are_equal(expected_out, self._get_output(wrapped))

            # It must also work from inside the unflatten_params() context.
            with wrapped.unflatten_params():
                wrapped.load_state_dict(expected_sd)
            assert objects_are_equal(expected_out, self._get_output(wrapped))