def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(
        checkpoint_storage_test_parameterized_setup):
    """Round-trip the parameterized value through checkpoint storage.

    The whole fixture value is saved as-is. It must come back equal
    (according to ``_equals``) after being reloaded from
    ``pytest.checkpoint_path``.
    """
    original_state = checkpoint_storage_test_parameterized_setup
    checkpoint_path = pytest.checkpoint_path

    _checkpoint_storage.save(original_state, checkpoint_path)
    reloaded_state = _checkpoint_storage.load(checkpoint_path)

    assert _equals(original_state, reloaded_state)
def test_checkpoint_storage_saved_dict_matches_loaded(
        checkpoint_storage_test_parameterized_setup):
    """Save a dict, reload it with extra keyword arguments, compare to expected.

    The parameterized setup is indexed as ``(to_save, key_arg, expected)``.
    ``key_arg`` is a mapping that is forwarded as keyword arguments to
    ``_checkpoint_storage.load``. ``_numpy_types`` presumably checks that
    the loaded values are numpy types (TODO confirm against its definition).
    """
    params = checkpoint_storage_test_parameterized_setup
    to_save, key_arg, expected = params[0], params[1], params[2]

    _checkpoint_storage.save(to_save, pytest.checkpoint_path)
    loaded = _checkpoint_storage.load(pytest.checkpoint_path, **key_arg)

    assert _equals(loaded, expected)
    assert _numpy_types(loaded)
def test_checkpoint_storage_saving_multiple_dimension_tensors(
        checkpoint_storage_test_parameterized_setup):
    """Round-trip a multi-dimensional tensor and verify shape and values.

    The parameterized setup is indexed as ``(tensor_dict, tensor_name)``.
    ``tensor_dict[tensor_name]`` must support ``.numpy()``; it is presumably a
    torch tensor. The reloaded entry must be a numpy array with the same
    shape and the same elements.
    """
    tensor_dict = checkpoint_storage_test_parameterized_setup[0]
    tensor_name = checkpoint_storage_test_parameterized_setup[1]

    _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path)
    loaded = _checkpoint_storage.load(pytest.checkpoint_path)

    expected = tensor_dict[tensor_name].numpy()
    actual = loaded[tensor_name]

    assert isinstance(actual, np.ndarray)
    # Compare shapes explicitly. An elementwise `==` followed by `.all()`
    # broadcasts, so a (1, N) original and an (N,) reloaded array would
    # compare equal even though the dimensions were not preserved.
    assert actual.shape == expected.shape
    assert np.array_equal(expected, actual)
def test_checkpoint_storage_saving_tensor_datatype(
        checkpoint_storage_test_parameterized_setup):
    """Check that a tensor's dtype maps to the expected numpy dtype on reload.

    The parameterized setup is indexed as
    ``(tensor_dict, tensor_name, tensor_dtype, np_dtype)``. The stored tensor
    must have ``tensor_dtype``. The reloaded numpy array must have
    ``np_dtype`` and hold the same elements.
    """
    params = checkpoint_storage_test_parameterized_setup
    tensor_dict, tensor_name = params[0], params[1]
    tensor_dtype, np_dtype = params[2], params[3]

    _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path)
    loaded = _checkpoint_storage.load(pytest.checkpoint_path)

    original_tensor = tensor_dict[tensor_name]
    loaded_array = loaded[tensor_name]

    assert isinstance(loaded_array, np.ndarray)
    assert original_tensor.dtype == tensor_dtype
    assert loaded_array.dtype == np_dtype
    assert (original_tensor.numpy() == loaded_array).all()
def test_checkpoint_storage_for_custom_user_dict_succeeds(
        checkpoint_storage_test_setup):
    """Store an arbitrary pickled user dict next to a tensor and recover both.

    The user dict (a tensor plus a ``_CustomClass`` instance) is pickled and
    hex-encoded with ``binascii.b2a_hex``. The resulting bytes are saved
    under ``'user_dict'``. After reloading, the hex is decoded and unpickled,
    and each recovered entry is compared with the original.
    """
    custom_class = _CustomClass()
    user_dict = {
        'tensor1': torch.tensor(np.arange(100), dtype=torch.float32),
        'custom_class': custom_class,
    }
    # Hex-encoded pickle: the arbitrary Python objects travel through the
    # checkpoint as a single bytes value.
    hex_payload = binascii.b2a_hex(pickle.dumps(user_dict))

    to_save = {
        'a': torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32),
        'user_dict': hex_payload,
    }
    _checkpoint_storage.save(to_save, pytest.checkpoint_path)
    restored = _checkpoint_storage.load(pytest.checkpoint_path)

    assert (restored['a'] == to_save['a'].numpy()).all()

    # Payload was produced by this test, so unpickling it is safe here.
    unpickled = pickle.loads(binascii.a2b_hex(restored['user_dict']))
    assert torch.all(unpickled['tensor1'].eq(user_dict['tensor1']))
    # NOTE(review): relies on _CustomClass defining a meaningful __eq__.
    assert unpickled['custom_class'] == custom_class
def test_checkpoint_storage_load_file_that_does_not_exist_fails(
        checkpoint_storage_test_setup):
    """Loading from a checkpoint path with no file behind it must raise.

    The path is resolved before entering ``pytest.raises``. If it were
    resolved inside, a missing ``pytest.checkpoint_path`` attribute would
    raise ``AttributeError``, be caught by the broad ``Exception`` guard, and
    make this test pass without ever calling ``load``.
    """
    missing_path = pytest.checkpoint_path

    # NOTE(review): any Exception satisfies this check; narrow it to the
    # specific error _checkpoint_storage.load raises once that is confirmed.
    with pytest.raises(Exception):
        _checkpoint_storage.load(missing_path)