def test_state_dict_equality_shared_params(self):
    """Check that flattening a shared-parameter transformer preserves its state dict.

    The module from ``self._get_shared_params_transformer()`` is wrapped in
    ``FlattenParamsWrapper``. The wrapper's ``state_dict()`` must have the same
    keys as the unwrapped module's, and equal values as judged by
    ``objects_are_equal``.

    This method is named distinctly from the generic ``test_state_dict_equality``.
    Two methods with the same name in one class body rebind the same attribute,
    so only the last definition would ever be collected and run.
    """
    module = self._get_shared_params_transformer()
    ref_state_dict = module.state_dict()

    flat_module = FlattenParamsWrapper(module)
    flat_state_dict = flat_module.state_dict()

    # Compare key sets first so a missing/extra entry yields a readable failure
    # instead of a large dump of tensor contents.
    assert (
        ref_state_dict.keys() == flat_state_dict.keys()
    ), f"{ref_state_dict.keys()} != {flat_state_dict.keys()}"
    assert objects_are_equal(ref_state_dict, flat_state_dict), f"{ref_state_dict} != {flat_state_dict}"
def test_empty_module(self):
    """Wrapping a parameterless module must expose no params and keep outputs unchanged.

    The module from ``self._get_empty_module()`` is run once unwrapped to get a
    reference output. After wrapping with ``FlattenParamsWrapper``, the test
    checks that the wrapper exposes no parameters and an empty state dict, and
    that it produces the same output for the same input.
    """
    sample = torch.rand(1)
    base = self._get_empty_module()
    expected = base(sample)

    wrapped = FlattenParamsWrapper(base)
    # Nothing to flatten: neither parameters nor state-dict entries should appear.
    assert not list(wrapped.parameters())
    assert not wrapped.state_dict()

    actual = wrapped(sample)
    torch.testing.assert_allclose(expected, actual)
def test_flatten_nothing(self):
    """An empty ``param_list`` must leave the module's state dict and output unchanged.

    The transformer is wrapped with ``param_list=[[]]``. The test checks that
    the wrapper's state dict has the same keys and values as a snapshot taken
    before wrapping, and that the forward output matches the unwrapped output.
    """
    net = self._get_transformer()
    expected_out = self._get_output(net)
    # Snapshot with clones so the reference values cannot change if the
    # wrapper touches the underlying parameter storage.
    snapshot = {name: tensor.clone() for name, tensor in net.state_dict().items()}

    net = FlattenParamsWrapper(net, param_list=[[]])
    wrapped_sd = net.state_dict()

    assert snapshot.keys() == wrapped_sd.keys()
    for name, tensor in snapshot.items():
        torch.testing.assert_allclose(tensor, wrapped_sd[name])

    actual_out = self._get_output(net)
    torch.testing.assert_allclose(expected_out, actual_out)
def test_state_dict_equality(self):
    """Test that unflattened state dict matches original (unwrapped) one.

    Every module from ``self._get_module_init_fns()`` is first built, in order.
    Each one is then wrapped with ``FlattenParamsWrapper``, and the wrapper's
    ``state_dict()`` must match the original in both keys and values.
    """
    # Build all modules up front (rather than lazily inside the loop) so their
    # construction order is not interleaved with the wrapping work.
    candidates = [make_module() for make_module in self._get_module_init_fns()]

    for candidate in candidates:
        expected_sd = candidate.state_dict()
        actual_sd = FlattenParamsWrapper(candidate).state_dict()

        expected_keys = expected_sd.keys()
        actual_keys = actual_sd.keys()
        assert expected_keys == actual_keys, f"{expected_keys} != {actual_keys}"
        assert objects_are_equal(expected_sd, actual_sd), f"{expected_sd} != {actual_sd}"