コード例 #1
0
def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(
        checkpoint_storage_test_parameterized_setup):
    """A dictionary written to checkpoint storage reads back unchanged.

    The parameterized fixture value is the dictionary itself (per the test
    name, presumably an empty or empty-nested dict — confirm against the
    fixture). It is saved to ``pytest.checkpoint_path``, loaded back, and
    compared with ``_equals``.
    """
    original = checkpoint_storage_test_parameterized_setup
    _checkpoint_storage.save(original, pytest.checkpoint_path)

    round_tripped = _checkpoint_storage.load(pytest.checkpoint_path)
    assert _equals(original, round_tripped)
コード例 #2
0
def test_checkpoint_storage_saved_dict_matches_loaded(
        checkpoint_storage_test_parameterized_setup):
    """Loading with keyword options returns the expected dictionary.

    The parameterized fixture supplies an indexable triple:
    ``[0]`` the dict to save, ``[1]`` a mapping of keyword arguments that is
    forwarded to ``_checkpoint_storage.load`` (presumably selecting part of
    the checkpoint — verify against the fixture), and ``[2]`` the dict that
    ``load`` is expected to return. ``_numpy_types`` is an external helper;
    presumably it checks that loaded values are numpy types.
    """
    params = checkpoint_storage_test_parameterized_setup
    payload, load_kwargs, expected_dict = params[0], params[1], params[2]

    _checkpoint_storage.save(payload, pytest.checkpoint_path)
    result = _checkpoint_storage.load(pytest.checkpoint_path, **load_kwargs)

    assert _equals(result, expected_dict)
    assert _numpy_types(result)
コード例 #3
0
def test_checkpoint_storage_saving_multiple_dimension_tensors(
        checkpoint_storage_test_parameterized_setup):
    """A multi-dimensional tensor survives a save/load round trip intact.

    The parameterized fixture supplies ``[0]`` a dict of tensors (objects
    with a ``.numpy()`` method) and ``[1]`` the key of the tensor under test.
    After loading, that entry must be a ``numpy.ndarray`` with the same shape
    and values as the original tensor.
    """
    tensor_dict = checkpoint_storage_test_parameterized_setup[0]
    tensor_name = checkpoint_storage_test_parameterized_setup[1]

    _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path)

    loaded = _checkpoint_storage.load(pytest.checkpoint_path)
    assert isinstance(loaded[tensor_name], np.ndarray)
    # np.array_equal also requires identical shapes. An element-wise `==`
    # followed by `.all()` broadcasts, so e.g. a (1, N) tensor reloaded with
    # shape (N,) would still compare equal and hide a lost dimension.
    assert np.array_equal(tensor_dict[tensor_name].numpy(),
                          loaded[tensor_name])
コード例 #4
0
def test_checkpoint_storage_saving_tensor_datatype(
        checkpoint_storage_test_parameterized_setup):
    """A tensor's dtype maps to the expected numpy dtype after a round trip.

    The parameterized fixture supplies ``[0]`` a dict of tensors, ``[1]`` the
    key of the tensor under test, ``[2]`` that tensor's expected dtype before
    saving and ``[3]`` the numpy dtype the loaded array must have. Values and
    shape must also be preserved.
    """
    tensor_dict = checkpoint_storage_test_parameterized_setup[0]
    tensor_name = checkpoint_storage_test_parameterized_setup[1]
    tensor_dtype = checkpoint_storage_test_parameterized_setup[2]
    np_dtype = checkpoint_storage_test_parameterized_setup[3]

    _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path)

    loaded = _checkpoint_storage.load(pytest.checkpoint_path)
    assert isinstance(loaded[tensor_name], np.ndarray)
    # Sanity-check the fixture itself, then the dtype after loading.
    assert tensor_dict[tensor_name].dtype == tensor_dtype
    assert loaded[tensor_name].dtype == np_dtype
    # np.array_equal checks shape as well as values; `==` + `.all()` would
    # broadcast and could accept a reloaded array of the wrong shape.
    assert np.array_equal(tensor_dict[tensor_name].numpy(),
                          loaded[tensor_name])
コード例 #5
0
def test_checkpoint_storage_for_custom_user_dict_succeeds(
        checkpoint_storage_test_setup):
    """Arbitrary user objects can be stored in a checkpoint as pickled hex.

    A dict holding a float32 tensor and a ``_CustomClass`` instance is
    pickled, hex-encoded to ``bytes`` and saved next to a plain tensor. After
    loading, the plain tensor and the unpickled user objects must match the
    originals. ``checkpoint_storage_test_setup`` is requested only for its
    side effects (presumably preparing ``pytest.checkpoint_path``).
    """
    custom_class = _CustomClass()
    user_dict = {
        'tensor1': torch.tensor(np.arange(100), dtype=torch.float32),
        'custom_class': custom_class
    }

    # Hex-encode the pickle; presumably so the storage backend only has to
    # handle a plain ASCII bytes value rather than arbitrary binary data.
    pickled_bytes = binascii.b2a_hex(pickle.dumps(user_dict))
    to_save = {
        'a': torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32),
        'user_dict': pickled_bytes
    }
    _checkpoint_storage.save(to_save, pytest.checkpoint_path)

    loaded_dict = _checkpoint_storage.load(pytest.checkpoint_path)
    # np.array_equal checks shape as well as values (no broadcasting).
    assert np.array_equal(loaded_dict['a'], to_save['a'].numpy())
    # pickle.loads is acceptable here: the payload was produced by this test.
    loaded_obj = pickle.loads(binascii.a2b_hex(loaded_dict['user_dict']))

    # torch.equal requires identical sizes as well as identical elements.
    assert torch.equal(loaded_obj['tensor1'], user_dict['tensor1'])
    # NOTE(review): relies on _CustomClass (defined outside this view)
    # implementing value equality; with default identity equality this would
    # always fail after unpickling.
    assert loaded_obj['custom_class'] == custom_class
コード例 #6
0
def test_checkpoint_storage_load_file_that_does_not_exist_fails(
        checkpoint_storage_test_setup):
    """Loading a checkpoint from a path with no file must raise.

    ``checkpoint_storage_test_setup`` is requested only for its side effects;
    presumably it points ``pytest.checkpoint_path`` at a path that has not
    been written yet (verify against the fixture).
    """
    # Resolve the path before entering ``pytest.raises``. That context accepts
    # any Exception, so an AttributeError from a missing or misconfigured
    # ``pytest.checkpoint_path`` would otherwise make this test pass vacuously.
    checkpoint_path = pytest.checkpoint_path
    # NOTE(review): the concrete exception type depends on the storage
    # backend, which is not visible here; narrow it once confirmed.
    with pytest.raises(Exception):
        _checkpoint_storage.load(checkpoint_path)