def test_state_dict_equality(self):
        """Check that FlattenParamsWrapper's state dict matches the unwrapped module's.

        The reference state dict is taken from the module before wrapping; the
        wrapped module's ``state_dict()`` must expose the same keys and equal
        values.
        """
        # NOTE(review): judging by its name, this fixture presumably shares
        # parameters between submodules -- confirm against the helper.
        module = self._get_shared_params_transformer()
        # Capture the reference before the wrapper restructures the module.
        ref_state_dict = module.state_dict()

        flat_module = FlattenParamsWrapper(module)
        flat_state_dict = flat_module.state_dict()

        # Compare keys first so a naming mismatch produces a readable message
        # instead of a bare objects_are_equal failure.
        assert (
            ref_state_dict.keys() == flat_state_dict.keys()
        ), f"{ref_state_dict.keys()} != {flat_state_dict.keys()}"
        assert objects_are_equal(ref_state_dict, flat_state_dict), f"{ref_state_dict} != {flat_state_dict}"
 def test_empty_module(self):
     """Check FlattenParamsWrapper around the (presumably parameter-free) empty module.

     After wrapping, the module must expose no parameters and an empty state
     dict, and its forward output must match the unwrapped module's output.
     """
     module = self._get_empty_module()
     sample = torch.rand(1)
     reference = module(sample)

     wrapped = FlattenParamsWrapper(module)
     assert not list(wrapped.parameters())
     assert not wrapped.state_dict()

     wrapped_out = wrapped(sample)
     # NOTE(review): torch.testing.assert_allclose is deprecated since
     # PyTorch 1.12 in favor of torch.testing.assert_close; switch once the
     # minimum supported torch version allows it.
     torch.testing.assert_allclose(reference, wrapped_out)
 def test_flatten_nothing(self):
     """Check FlattenParamsWrapper with ``param_list=[[]]`` preserves the module.

     With an empty flatten group (nothing to flatten, judging by the test
     name), the wrapped module must report the same state-dict keys and
     values as the original and produce the same forward output.
     """
     module = self._get_transformer()
     reference_output = self._get_output(module)
     # state_dict() returns references to the live parameters/buffers, so
     # clone each entry to keep the reference independent of the wrapper.
     reference_sd = {name: tensor.clone() for name, tensor in module.state_dict().items()}

     wrapped = FlattenParamsWrapper(module, param_list=[[]])
     wrapped_sd = wrapped.state_dict()
     assert reference_sd.keys() == wrapped_sd.keys()
     for name in reference_sd:
         torch.testing.assert_allclose(reference_sd[name], wrapped_sd[name])

     wrapped_output = self._get_output(wrapped)
     torch.testing.assert_allclose(reference_output, wrapped_output)
Example #4
0
    def test_state_dict_equality(self):
        """Check every wrapped test module's state dict against its unwrapped original.

        All modules from ``self._get_module_init_fns()`` are instantiated
        first; each is then wrapped, and the wrapped state dict must expose
        the same keys and equal values as the reference taken before wrapping.
        """
        # Instantiate every test module up front, before any of them is wrapped.
        models = [make_model() for make_model in self._get_module_init_fns()]
        for model in models:
            expected_sd = model.state_dict()

            wrapped = FlattenParamsWrapper(model)
            actual_sd = wrapped.state_dict()

            expected_keys = expected_sd.keys()
            actual_keys = actual_sd.keys()
            assert expected_keys == actual_keys, f"{expected_keys} != {actual_keys}"
            assert objects_are_equal(expected_sd, actual_sd), f"{expected_sd} != {actual_sd}"