Ejemplo n.º 1
0
def test_save_load(deepspeech: DeepSpeech, config: Configuration,
                   config_path: str, alphabet_path: str, model_dir: str):
    """Round-trip the model weights through ``save`` and ``load``.

    The weights of ``deepspeech.model`` are written to ``weights.hdf5`` in
    ``model_dir``. A second ``DeepSpeech`` is then constructed from the same
    config/alphabet files and given a fresh model built with
    ``random_state=123``. Its weights must differ from the saved ones before
    loading and match them exactly after ``load``.
    """
    saved_path = os.path.join(model_dir, 'weights.hdf5')
    expected = deepspeech.model.get_weights()
    deepspeech.save(saved_path)

    restored = DeepSpeech.construct(config_path, alphabet_path)
    restored.model = deepspeech.get_model(**config.model, is_gpu=False,
                                          random_state=123)
    # Guard: if the fresh weights already matched, the post-load equality
    # check below would pass trivially and prove nothing about ``load``.
    assert not is_same(expected, restored.model.get_weights())

    restored.load(saved_path)
    assert is_same(expected, restored.model.get_weights())
Ejemplo n.º 2
0
def test_compile_model(config: Configuration):
    """Check that ``DeepSpeech.compile_model`` returns a compiled model.

    The model, optimizer and loss all come from the ``DeepSpeech`` factory
    methods, configured from ``config.model`` and ``config.optimizer``.
    ``gpus=[]`` is passed (presumably meaning no GPU placement; confirm
    against ``compile_model``).
    """
    compiled = DeepSpeech.compile_model(
        DeepSpeech.get_model(**config.model, is_gpu=False),
        DeepSpeech.get_optimizer(**config.optimizer),
        DeepSpeech.get_loss(),
        gpus=[],
    )
    assert compiled._is_compiled
Ejemplo n.º 3
0
def test_get_decoder(config: Configuration, alphabet: Alphabet):
    """Check that ``DeepSpeech.get_decoder`` returns a callable.

    The decoder is built for ``alphabet`` and a freshly created model, with
    any extra options taken from ``config.decoder``.
    """
    acoustic_model = DeepSpeech.get_model(**config.model, is_gpu=False)
    decode = DeepSpeech.get_decoder(
        alphabet=alphabet, model=acoustic_model, **config.decoder)
    assert callable(decode)
Ejemplo n.º 4
0
def test_get_model(config: Configuration):
    """Check that ``DeepSpeech.get_model`` builds a reproducible ``Model``.

    Two models are built from the same ``config.model`` without passing
    ``random_state``. The first must be a ``Model`` (subclasses allowed), and
    both must have identical weights. This implies initialisation is
    deterministic, presumably via a default ``random_state`` in
    ``get_model`` (confirm against its signature).
    """
    model = DeepSpeech.get_model(**config.model, is_gpu=False)
    # isinstance rather than ``type(...) ==`` so subclasses of Model pass.
    assert isinstance(model, Model)

    # Same arguments again: equal weights show the default seed makes
    # weight initialisation reproducible.
    new_model = DeepSpeech.get_model(**config.model, is_gpu=False)
    assert is_same(model.get_weights(), new_model.get_weights())