def test_save_to_file_and_load_from_file_with_extensions(
        self, mock_add_extensions, mock_import_module,
        mock_installed_module, compress_save_file, pickle_only, schema):
    """Test that extensions are saved to the output config.yaml and
    they are also added when loading back the object.
    """
    mock_installed_module.return_value = True
    schema_obj = schema()
    # Add extensions manually because if we use add_extensions_metadata
    # then no extensions will be added as the schema doesn't contain
    # any prefix.
    schema_obj._extensions = TestSerializationExtensions.EXTENSIONS
    obj = schema_obj()
    state = obj.get_state()
    with tempfile.TemporaryDirectory() as tmp_root:
        save_path = os.path.join(tmp_root, 'savefile.flambe')
        save_state_to_file(state, save_path, compress_save_file, pickle_only)
        list_files(save_path)
        # The saver appends these suffixes; mirror them to locate the file.
        if pickle_only:
            save_path += '.pkl'
        if compress_save_file:
            save_path += '.tar.gz'
        loaded_state = load_state_from_file(save_path)
        check_mapping_equivalence(state, loaded_state)
        check_mapping_equivalence(state._metadata, loaded_state._metadata)
        _ = Basic.load_from_path(save_path)
        mock_add_extensions.assert_called_once_with(
            TestSerializationExtensions.EXTENSIONS)
def test_save_to_file_and_load_from_file_roundtrip_complex_nontorch_root(
        self, complex_multi_layered_nontorch_root, pickle_only,
        compress_save_file):
    """Round-trip a complex multi-layered hierarchy whose root is not a
    torch module, checking state, metadata, and a fresh object reload.
    """
    TORCH_TAG_PREFIX = "torch"
    make_component(torch.nn.Module, TORCH_TAG_PREFIX, only_module='torch.nn')
    old_obj = complex_multi_layered_nontorch_root(from_config=True)
    state = old_obj.get_state()
    with tempfile.TemporaryDirectory() as tmp_root:
        save_path = os.path.join(tmp_root, 'savefile.flambe')
        save_state_to_file(state, save_path, compress_save_file, pickle_only)
        list_files(save_path)
        # The saver appends these suffixes; mirror them to locate the file.
        if pickle_only:
            save_path += '.pkl'
        if compress_save_file:
            save_path += '.tar.gz'
        loaded_state = load_state_from_file(save_path)
        check_mapping_equivalence(state, loaded_state)
        check_mapping_equivalence(state._metadata, loaded_state._metadata)
        new_obj = complex_multi_layered_nontorch_root(from_config=True)
        # Snapshot the fresh object's state before loading, so its
        # metadata can be compared against what was loaded from disk.
        initial_state = new_obj.get_state()
        new_obj.load_state(loaded_state, strict=False)
        old_state = old_obj.get_state()
        new_state = new_obj.get_state()
        check_mapping_equivalence(new_state, old_state)
        check_mapping_equivalence(old_state._metadata, new_state._metadata)
        check_mapping_equivalence(initial_state._metadata,
                                  loaded_state._metadata)
def test_save_to_file_and_load_from_file_roundtrip_complex(
        self, complex_multi_layered):
    """Round-trip a complex multi-layered hierarchy through save/load,
    including a mutated attribute deep in the object tree.
    """
    TORCH_TAG_PREFIX = "torch"
    make_component(torch.nn.Module, TORCH_TAG_PREFIX, only_module='torch.nn')
    old_obj = complex_multi_layered(from_config=True)
    # Test that the current state is actually saved, for a
    # Component-only child of torch objects.
    old_obj.child.child.child.x = 24
    state = old_obj.get_state()
    with tempfile.TemporaryDirectory() as save_dir:
        save_state_to_file(state, save_dir)
        list_files(save_dir)
        loaded_state = load_state_from_file(save_dir)
        check_mapping_equivalence(state, loaded_state)
        new_obj = complex_multi_layered(from_config=True)
        new_obj.load_state(loaded_state, strict=False)
        old_state = old_obj.get_state()
        new_state = new_obj.get_state()
        check_mapping_equivalence(new_state, old_state)
        check_mapping_equivalence(old_state._metadata, new_state._metadata,
                                  exclude_config=False)
def test_module_save_and_load_roundtrip_pytorch_only_bridge(self):
    """Round-trip an Org whose item bridges torch-stateful and
    torch-only children, loading the saved state into a second Org
    built with deliberately different child values.

    The two check passes compare (1) the state written to / read from
    disk and (2) the states re-extracted from both live objects after
    loading.
    """
    a = BasicStateful.compile(x=3)
    b = 100
    c = BasicStatefulTwo.compile(y=0)
    item = ComposableTorchStatefulTorchOnlyChild.compile(a=a, b=b, c=c)
    old_obj = Org.compile(item=item, extra=None)
    # x for a2 should be different from instance a
    a2 = BasicStateful.compile(x=4)
    b2 = 101
    # y for c2 should be different from instance c
    c2 = BasicStatefulTwo.compile(y=1)
    item2 = ComposableTorchStatefulTorchOnlyChild.compile(a=a2, b=b2, c=c2)
    new_obj = Org.compile(item=item2, extra=None)
    with tempfile.TemporaryDirectory() as root_path:
        path = os.path.join(root_path, 'asavefile2.flambe')
        old_state = old_obj.get_state()
        save_state_to_file(old_state, path)
        new_state = load_state_from_file(path)
        new_obj.load_state(new_state)
        old_state_get = old_obj.get_state()
        new_state_get = new_obj.get_state()
        check_mapping_equivalence(new_state, old_state)
        check_mapping_equivalence(old_state._metadata, new_state._metadata,
                                  exclude_config=False)
        check_mapping_equivalence(new_state_get, old_state_get)
        check_mapping_equivalence(old_state_get._metadata,
                                  new_state_get._metadata,
                                  exclude_config=True)
def test_save_to_file_and_load_from_file_roundtrip_pytorch(
        self, alternating_nn_module_with_state):
    """Round-trip an alternating nn.Module hierarchy with state through
    save/load and verify states and metadata match.
    """
    old_obj = alternating_nn_module_with_state(from_config=False)
    saved_state = old_obj.get_state()
    with tempfile.TemporaryDirectory() as save_dir:
        save_state_to_file(saved_state, save_dir)
        loaded_state = load_state_from_file(save_dir)
        new_obj = alternating_nn_module_with_state(from_config=False)
        new_obj.load_state(loaded_state, strict=False)
        old_state = old_obj.get_state()
        new_state = new_obj.get_state()
        check_mapping_equivalence(new_state, old_state)
        check_mapping_equivalence(old_state._metadata, new_state._metadata,
                                  exclude_config=False)
def test_save_to_file_and_load_from_file_roundtrip(self, basic_object):
    """Round-trip a basic object's state through save/load and verify a
    non-config-built instance restores to an equivalent state.
    """
    old_obj = basic_object(from_config=True)
    saved_state = old_obj.get_state()
    with tempfile.TemporaryDirectory() as save_dir:
        save_state_to_file(saved_state, save_dir)
        loaded_state = load_state_from_file(save_dir)
        new_obj = basic_object(from_config=False)
        new_obj.load_state(loaded_state, strict=False)
        old_state = old_obj.get_state()
        new_state = new_obj.get_state()
        check_mapping_equivalence(new_state, old_state)
        check_mapping_equivalence(old_state._metadata, new_state._metadata,
                                  exclude_config=True)