예제 #1
0
def main(args):
    """Build a DeepSpeech model, optionally warm-start it, train it and save it.

    Args:
        args: parsed command-line namespace. ``pretrained_weights`` and
            ``epochs`` are read here; the whole namespace is also handed to
            ``create_generators``.

    Side effects:
        Writes the trained weights to ``WEIGHTS_PATH``.
    """
    model = DeepSpeech.construct(config_path=CONFIG_PATH,
                                 alphabet_path=ALPHABET_PATH)
    weights_to_restore = args.pretrained_weights
    if weights_to_restore:
        model.load(weights_to_restore)

    # create_generators returns a (train, dev) pair.
    generators = create_generators(model, args)
    train_gen, dev_gen = generators
    model.fit(train_gen, dev_gen, epochs=args.epochs, shuffle=False)
    model.save(WEIGHTS_PATH)
예제 #2
0
def setup_deepspeech(config_path: str,
                     alphabet_path: str,
                     pretrained_weights: str = '') -> DeepSpeech:
    """Construct a DeepSpeech model, loading pretrained weights when given.

    Args:
        config_path: configuration file passed to ``DeepSpeech.construct``.
        alphabet_path: alphabet file passed to ``DeepSpeech.construct``.
        pretrained_weights: weights file to load; an empty string (the
            default) skips loading.

    Returns:
        The constructed ``DeepSpeech`` instance.
    """
    model = DeepSpeech.construct(config_path, alphabet_path)
    if not pretrained_weights:
        return model
    model.load(pretrained_weights)
    return model
예제 #3
0
def test_fit(deepspeech: DeepSpeech, generator: DataGenerator,
             config_path: str, alphabet_path: str, test_dir: str):
    """Check that ``fit`` keeps the best weights tracked by the checkpoint callback.

    The current weights are saved to ``test_dir`` and registered on
    ``deepspeech.callbacks[1]`` as the best result (``best_result = 0``)
    before a one-epoch ``fit``. Afterwards:

    * ``fit`` must return a ``History`` object;
    * ``deepspeech.model`` must hold the same weights as a fresh model loaded
      from ``best_weights_path``;
    * ``deepspeech.compiled_model`` weights must equal their pre-``fit`` values.

    Files created by the test are removed even when an assertion fails.
    """
    weights_path = os.path.join(test_dir, 'weights_copy.hdf5')
    try:
        # Save the current weights and mark them as the best result so far.
        deepspeech.save(weights_path)
        # NOTE(review): compiled_model is presumably the (multi-GPU) distributed
        # wrapper around deepspeech.model — confirm in DeepSpeech.compile_model.
        distributed_weights = deepspeech.compiled_model.get_weights()
        model_checkpoint = deepspeech.callbacks[1]
        model_checkpoint.best_result = 0
        model_checkpoint.best_weights_path = weights_path

        history = deepspeech.fit(train_generator=generator,
                                 dev_generator=generator,
                                 epochs=1,
                                 shuffle=False)
        assert isinstance(history, History)

        # The in-memory model must hold the weights stored at best_weights_path.
        deepspeech_weights = deepspeech.model.get_weights()
        new_deepspeech = DeepSpeech.construct(config_path, alphabet_path)
        new_deepspeech.load(model_checkpoint.best_weights_path)
        new_deepspeech_weights = new_deepspeech.model.get_weights()
        assert is_same(deepspeech_weights, new_deepspeech_weights)

        # The compiled model must reflect the same (restored) weights.
        new_distributed_weights = deepspeech.compiled_model.get_weights()
        assert is_same(distributed_weights, new_distributed_weights)
    finally:
        # Best-effort cleanup: must not mask an assertion error raised above.
        # NOTE(review): the checkpoint directory path is fixed by the callback
        # configuration, not by test_dir — confirm it is always 'tests/checkpoints'.
        shutil.rmtree('tests/checkpoints', ignore_errors=True)
        # Remove the file actually written above, wherever test_dir points.
        if os.path.exists(weights_path):
            os.remove(weights_path)
예제 #4
0
def load_extended_model(config_path, alphabet_path, pretrained_weights):
    """Build a DeepSpeech model with an extended head and load its weights.

    The base model from ``DeepSpeech.construct`` is passed to ``freeze`` and
    then wrapped by ``create_extended_model``. The extended model replaces
    ``deepspeech.model`` and is compiled into ``deepspeech.compiled_model``.
    ``pretrained_weights`` is loaded only after that swap, so the weights file
    must match the extended architecture.

    Args:
        config_path: configuration file used by ``DeepSpeech.construct`` and
            ``Configuration``.
        alphabet_path: alphabet file used by ``DeepSpeech.construct``.
        pretrained_weights: weights file passed to ``deepspeech.load``
            (always loaded).

    Returns:
        The ``DeepSpeech`` instance with the extended model installed.
    """
    deepspeech = DeepSpeech.construct(config_path=config_path,
                                      alphabet_path=alphabet_path)

    # NOTE(review): presumably marks the base layers non-trainable — confirm in `freeze`.
    freeze(deepspeech.model)
    # Query the available devices once; the same list drives both model
    # creation (is_gpu) and compilation.
    gpus = get_available_gpus()
    config = Configuration(config_path)
    extended_model = create_extended_model(deepspeech.model,
                                           config,
                                           is_gpu=len(gpus) > 0)

    optimizer = DeepSpeech.get_optimizer(**config.optimizer)
    loss = DeepSpeech.get_loss()
    deepspeech.model = extended_model
    deepspeech.compiled_model = DeepSpeech.compile_model(
        extended_model, optimizer, loss, gpus)
    deepspeech.load(pretrained_weights)
    return deepspeech
예제 #5
0
def main(args):
    """Train an extended DeepSpeech model and save its weights to ``WEIGHTS_PATH``.

    If ``args.pretrained_weights`` is set, those weights are loaded into the
    base model *before* it is frozen and wrapped by ``create_extended_model``.
    The extended model then replaces ``deepspeech.model`` and is compiled into
    ``deepspeech.compiled_model`` before training.

    Args:
        args: parsed command-line namespace. ``pretrained_weights`` and
            ``epochs`` are read here; the whole namespace is also handed to
            ``create_generators``.
    """
    deepspeech = DeepSpeech.construct(config_path=CONFIG_PATH, alphabet_path=ALPHABET_PATH)
    if args.pretrained_weights:
        deepspeech.load(args.pretrained_weights)

    # NOTE(review): presumably marks the base layers non-trainable — confirm in `freeze`.
    freeze(deepspeech.model)
    # Query the available devices once; the same list drives both model
    # creation (is_gpu) and compilation.
    gpus = get_available_gpus()
    config = Configuration(CONFIG_PATH)
    extended_model = create_extended_model(deepspeech.model, config, is_gpu=len(gpus) > 0)

    optimizer = DeepSpeech.get_optimizer(**config.optimizer)
    loss = DeepSpeech.get_loss()
    deepspeech.model = extended_model
    deepspeech.compiled_model = DeepSpeech.compile_model(extended_model, optimizer, loss, gpus)

    train_generator, dev_generator = create_generators(deepspeech, args)
    deepspeech.fit(train_generator, dev_generator, epochs=args.epochs, shuffle=False)
    deepspeech.save(WEIGHTS_PATH)
예제 #6
0
def main(args):
    """Train a DeepSpeech model on ``args.train``, validate on ``args.dev``, save it.

    Args:
        args: parsed command-line namespace. Reads ``pretrained_weights``,
            ``train``, ``dev``, ``batch_size``, ``source``,
            ``shuffle_after_epoch`` and ``epochs``.

    Side effects:
        Writes the trained weights to the module-level ``weights_path``.
    """
    model = DeepSpeech.construct(config_path=config_path,
                                 alphabet_path=alphabet_path)
    if args.pretrained_weights:
        model.load(args.pretrained_weights)

    # Both generators share batch size and data source; only the training
    # generator is given shuffle_after_epoch.
    shared_options = dict(batch_size=args.batch_size, source=args.source)
    train_generator = model.create_generator(
        args.train,
        shuffle_after_epoch=args.shuffle_after_epoch,
        **shared_options)
    dev_generator = model.create_generator(args.dev, **shared_options)

    model.fit(train_generator, dev_generator, epochs=args.epochs, shuffle=False)
    model.save(weights_path)
예제 #7
0
def deepspeech(config_path: str, alphabet_path: str) -> DeepSpeech:
    """Return a freshly constructed ``DeepSpeech`` model for the given files."""
    return DeepSpeech.construct(config_path=config_path,
                                alphabet_path=alphabet_path)