def test_flat_state_dict(self):
        """Check that reloading a flat state dict reproduces the original outputs.

        For each module factory from ``_get_module_init_fns``, a wrapped module
        built with the factory's default arguments is run to get reference
        outputs. Its ``flat_state_dict()`` is then loaded into a second wrapped
        module built with ``seed=1234``, and both outputs must be equal.
        """
        for init_fn in self._get_module_init_fns():
            # Reference side: record outputs first, then snapshot the flat state.
            source = FlattenParamsWrapper(init_fn())
            expected = self._get_output(source)
            saved = source.flat_state_dict()

            # Target side: a differently seeded module (presumably with
            # different initial parameters), so a match shows the load took effect.
            target = FlattenParamsWrapper(init_fn(seed=1234))
            target.load_state_dict(saved)
            actual = self._get_output(target)

            assert objects_are_equal(expected, actual)

    def test_flat_state_dict_shared_params(self):
        """Test flat state dict reload for a transformer with shared parameters.

        A wrapped transformer from ``_get_shared_params_transformer`` is run to
        get reference outputs. Its ``flat_state_dict()`` is then loaded into a
        second wrapped transformer built with ``seed=1234``, and the outputs of
        the two must be equal.
        """
        # NOTE: this name must stay distinct from ``test_flat_state_dict``.
        # A second ``def`` with the same name in a class body replaces the
        # first, so only one of the two tests would ever be collected and run.
        flat_module = self._get_shared_params_transformer()
        flat_module = FlattenParamsWrapper(flat_module)
        ref_output = self._get_output(flat_module)

        flat_state_dict = flat_module.flat_state_dict()

        # Differently seeded (presumably different initial parameters), so
        # equal outputs indicate the loaded state dict actually took effect.
        new_module = self._get_shared_params_transformer(seed=1234)
        new_module = FlattenParamsWrapper(new_module)
        new_module.load_state_dict(flat_state_dict)
        new_output = self._get_output(new_module)

        assert objects_are_equal(ref_output, new_output)