コード例 #1
0
def test_fit(deepspeech: DeepSpeech, generator: DataGenerator,
             config_path: str, alphabet_path: str, model_dir: str):
    """Check `fit` returns a history and leaves the best weights in the model.

    The current weights are saved to `weights_copy.hdf5` and registered on the
    third callback as the best result so far.  After two epochs the test
    asserts that the model weights equal the weights stored at the callback's
    `best_weights_path`, and that the distributed model's weights equal the
    ones captured before training.
    """
    # Save the current weights and register them as the best result so far.
    weights_path = os.path.join(model_dir, 'weights_copy.hdf5')
    deepspeech.save(weights_path)
    distributed_weights = deepspeech.distributed_model.get_weights()
    # NOTE(review): index 2 is presumably the model-checkpoint callback --
    # confirm against the callbacks configuration.
    model_checkpoint = deepspeech.callbacks[2]
    # NOTE(review): 0 is presumably a score that training cannot improve on,
    # so the saved copy stays the best result -- confirm.
    model_checkpoint.best_result = 0
    model_checkpoint.best_weights_path = weights_path

    history = deepspeech.fit(train_generator=generator,
                             dev_generator=generator,
                             epochs=2,
                             shuffle=False)
    # isinstance is the idiomatic type check and also accepts subclasses.
    assert isinstance(history, History)

    # The model must hold the weights stored at the checkpoint's best path.
    deepspeech_weights = deepspeech.model.get_weights()
    new_deepspeech = DeepSpeech.construct(config_path, alphabet_path)
    new_deepspeech.load(model_checkpoint.best_weights_path)
    new_deepspeech_weights = new_deepspeech.model.get_weights()
    assert is_same(deepspeech_weights, new_deepspeech_weights)

    # The distributed model must hold the weights captured before training.
    new_distributed_weights = deepspeech.distributed_model.get_weights()
    assert is_same(distributed_weights, new_distributed_weights)
コード例 #2
0
def test_get_labels_and_get_transcripts(deepspeech: DeepSpeech):
    """Round-trip transcripts through `get_labels` and `get_transcripts`.

    The first input contains the characters `$` and `!`, which are absent
    from the transcript expected after the round trip.
    """
    raw_texts = ['to jest je$st!', 'test']
    expected_texts = ['to jest jest', 'test']

    encoded = deepspeech.get_labels(raw_texts)
    assert encoded.dtype == np.int64
    assert encoded.shape == (2, 12)
    # 'test' has four characters; the position right after them must hold
    # the alphabet's blank token.
    assert encoded[1, 4] == deepspeech.alphabet.blank_token

    decoded_texts = deepspeech.get_transcripts(encoded)
    assert decoded_texts == expected_texts
コード例 #3
0
def test_save_load(deepspeech: DeepSpeech, config: Configuration,
                   config_path: str, alphabet_path: str, model_dir: str):
    """Check that saved weights can be loaded into a second instance.

    A second `DeepSpeech` is given a model built with `random_state=123`, so
    its weights must differ from the fixture's until the saved file is loaded.
    """
    original_weights = deepspeech.model.get_weights()
    checkpoint_path = os.path.join(model_dir, 'weights.hdf5')
    deepspeech.save(checkpoint_path)

    restored = DeepSpeech.construct(config_path, alphabet_path)
    reseeded_model = deepspeech.get_model(**config.model,
                                          is_gpu=False,
                                          random_state=123)
    restored.model = reseeded_model
    # Before loading, the two models must not share weights.
    assert not is_same(original_weights, restored.model.get_weights())

    # After loading, the weights must match the saved ones.
    restored.load(checkpoint_path)
    assert is_same(original_weights, restored.model.get_weights())
コード例 #4
0
def test_create_generator_from_audio_files(deepspeech: DeepSpeech):
    """Build a generator from the audio CSV and check its first batch shapes."""
    options = dict(file_path='tests/data/audio.csv',
                   source='from_audio_files',
                   batch_size=2)
    audio_generator = deepspeech.create_generator(**options)
    assert len(audio_generator) == 2

    first_batch = audio_generator[0]
    features, labels = first_batch
    assert features.shape == (2, 739, 26)
    assert labels.shape == (2, 206)
コード例 #5
0
def test_create_generator_from_prepared_features(deepspeech: DeepSpeech):
    """Build a generator from the HDF5 features and check its first batch."""
    options = dict(file_path='tests/data/features.hdf5',
                   source='from_prepared_features',
                   batch_size=2)
    features_generator = deepspeech.create_generator(**options)
    assert len(features_generator) == 6

    first_batch = features_generator[0]
    features, labels = first_batch
    assert features.shape == (2, 93, 26)
    assert labels.shape == (2, 39)
コード例 #6
0
def test_trainable():
    """Check that frozen base layers keep their weights while training.

    Builds the base and the extended model from their configuration files,
    loads the saved base weights into the extended model by layer name, runs
    ten dummy training batches and asserts that the `base_*` layers still
    match the base model while the extended model's weights have changed.
    The temporary weights file is removed even when an assertion fails.
    """
    frozen_layers = ['base_1', 'base_2', 'base_3']
    extension_layers = ['extension_1', 'extension_2']

    gpus = deepspeech.get_available_gpus(
    )  # Support both Multi and Single-GPU tests
    is_gpu = len(gpus) > 0

    base_configuration = DeepSpeech.get_configuration(
        'tests/models/base/configuration.yaml')
    base_configuration.model.pop('name')
    base_model = deepspeech_custom(is_gpu=is_gpu,
                                   **base_configuration.model)
    fname = 'weights.hdf5'
    try:
        base_model.save_weights(fname)

        extended_configuration = DeepSpeech.get_configuration(
            'tests/models/extended/configuration.yaml')
        extended_configuration.model.pop('name')
        extended_model = deepspeech_custom(is_gpu=is_gpu,
                                           **extended_configuration.model)
        weights_before_training = extended_model.get_weights()

        # Base layers are frozen, extension layers are trainable, and before
        # loading the base layers do not yet share weights with the base model.
        assert all(not extended_model.get_layer(name).trainable
                   for name in frozen_layers)
        assert all(
            extended_model.get_layer(name).trainable
            for name in extension_layers)
        assert all(not is_same(
            base_model.get_layer(name).get_weights(),
            extended_model.get_layer(name).get_weights())
                   for name in frozen_layers)

        loss = DeepSpeech.get_loss()
        optimizer = DeepSpeech.get_optimizer(
            **extended_configuration.optimizer)
        parallel_model = DeepSpeech.distribute_model(extended_model, gpus)
        DeepSpeech.compile_model(parallel_model, optimizer, loss)
        extended_model.load_weights(fname, by_name=True)

        # Loading by name must copy the base weights into the base layers.
        assert all(
            is_same(
                base_model.get_layer(name).get_weights(),
                extended_model.get_layer(name).get_weights())
            for name in frozen_layers)

        for _ in range(10):  # Dummy training (10 epochs / 10 batch_size)
            X = np.random.rand(10, 100, 80)
            y = np.random.randint(0, 35, size=[10, 20], dtype=np.int32)
            parallel_model.train_on_batch(X, y)

        assert all(
            is_same(
                base_model.get_layer(name).get_weights(),
                extended_model.get_layer(name).get_weights())
            for name in frozen_layers
        ), "Freezed layers have to be unchangeable."
        assert not is_same(
            weights_before_training,
            extended_model.get_weights()), "The base model updates weights."
        # X is the last dummy batch produced by the training loop above.
        assert is_close(extended_model.predict(X), parallel_model.predict(
            X)), "The results are the same (compiled model)."
        # NOTE: the full weight lists of the two models are not compared;
        # the original author noted that their order differs.
    finally:
        # Clean up on every path; the file may be missing if saving failed.
        if os.path.exists(fname):
            os.remove(fname)
コード例 #7
0
ファイル: utils.py プロジェクト: wansuiye09/DeepSpeech-Keras
def load(name: str):
    """Construct a `DeepSpeech` instance from a model directory or model name.

    If `name` is an existing directory it is used directly; otherwise the
    directory is looked up with `get_pretrained_model_dir(name)`.  The
    directory must contain `configuration.yaml` and `alphabet.txt`.  Weights
    are loaded from `weights.hdf5` only when that file exists.

    Returns:
        The constructed `DeepSpeech` instance.

    Raises:
        AssertionError: if the configuration or alphabet file is missing.
    """
    # Imported inside the function, as in the original -- presumably to avoid
    # a circular import; confirm before moving it to module level.
    from deepspeech import DeepSpeech

    model_dir = name if os.path.isdir(name) else get_pretrained_model_dir(name)
    config_path, alphabet_path, weights_path = (
        os.path.join(model_dir, file_name)
        for file_name in ('configuration.yaml', 'alphabet.txt', 'weights.hdf5'))

    assert os.path.isfile(config_path), 'The config file required in the directory'
    assert os.path.isfile(alphabet_path), 'The alphabet file required in the directory'

    instance = DeepSpeech.construct(config_path, alphabet_path)
    if not os.path.isfile(weights_path):
        return instance
    instance.load(weights_path)
    return instance
コード例 #8
0
def config(config_path) -> Configuration:
    """Return the configuration read from `config_path`."""
    configuration = DeepSpeech.get_configuration(config_path)
    return configuration
コード例 #9
0
def test_compile_model(config: Configuration):
    """Compile a CPU model with no GPUs and check it is flagged as compiled."""
    cpu_model = DeepSpeech.get_model(**config.model, is_gpu=False)
    model_optimizer = DeepSpeech.get_optimizer(**config.optimizer)
    model_loss = DeepSpeech.get_loss()

    compiled = DeepSpeech.compile_model(cpu_model,
                                        model_optimizer,
                                        model_loss,
                                        gpus=[])
    assert compiled._is_compiled
コード例 #10
0
def test_get_callbacks(model_dir: str, config: Configuration):
    """Check the number of callbacks built from the configuration."""
    expected_count = 6
    created = DeepSpeech.get_callbacks(home_dir=model_dir,
                                       configurations=config.callbacks)
    assert len(created) == expected_count
コード例 #11
0
def test_get_decoder(config: Configuration, alphabet: Alphabet):
    """Check that `get_decoder` returns a callable for a CPU model."""
    cpu_model = DeepSpeech.get_model(**config.model, is_gpu=False)
    decode = DeepSpeech.get_decoder(alphabet=alphabet,
                                    model=cpu_model,
                                    **config.decoder)
    assert callable(decode)
コード例 #12
0
def test_get_features_extractor(config: Configuration):
    """Check `get_features_extractor` returns a `FeaturesExtractor`."""
    features_extractor = DeepSpeech.get_features_extractor(
        **config.features_extractor)
    # isinstance is the idiomatic type check and also accepts subclasses.
    assert isinstance(features_extractor, FeaturesExtractor)
コード例 #13
0
def test_get_model(config: Configuration):
    """Check `get_model` returns a `Model` with reproducible weights.

    The model is built twice from the same configuration; both builds must
    have identical weights (the original comment attributes this to the
    random seed).
    """
    model = DeepSpeech.get_model(**config.model, is_gpu=False)
    # isinstance rather than an exact type comparison: the check still holds
    # if the factory returns a subclass of Model.
    assert isinstance(model, Model)
    new_model = DeepSpeech.get_model(**config.model, is_gpu=False)
    assert is_same(model.get_weights(),
                   new_model.get_weights())  # Test random seed
コード例 #14
0
def generator(deepspeech: DeepSpeech) -> DataGenerator:
    """Return a batch-size-2 generator over the prepared HDF5 features."""
    options = dict(file_path='tests/data/features.hdf5',
                   source='from_prepared_features',
                   batch_size=2)
    return deepspeech.create_generator(**options)
コード例 #15
0
def test_get_features(deepspeech: DeepSpeech, audio_file_paths: List[str]):
    """Check the shape of the features extracted from the audio files."""
    expected_shape = (2, 739, 26)
    extracted = deepspeech.get_features(audio_file_paths)
    assert extracted.shape == expected_shape
コード例 #16
0
def alphabet(alphabet_path: str) -> Alphabet:
    """Return the alphabet read from `alphabet_path`."""
    loaded_alphabet = DeepSpeech.get_alphabet(alphabet_path)
    return loaded_alphabet
コード例 #17
0
def deepspeech(config_path: str, alphabet_path: str) -> DeepSpeech:
    """Return a `DeepSpeech` built from the configuration and alphabet files."""
    instance = DeepSpeech.construct(config_path, alphabet_path)
    return instance
コード例 #18
0
import sys
import scipy.io.wavfile as wav
from deepspeech import DeepSpeech

# Usage: <script> <model_path> <wav_path>
# NOTE(review): 26 and 9 are passed straight to the DeepSpeech constructor;
# presumably the feature count and context size -- confirm against the
# installed deepspeech binding.
ds = DeepSpeech(sys.argv[1], 26, 9)
# wav.read returns the sample rate followed by the audio samples.
fs, audio = wav.read(sys.argv[2])
# print is called as a function: the Python 2 print statement is a syntax
# error on Python 3, while this single-argument form works on both.
print(ds.stt(audio, fs))