def test_fit(deepspeech: DeepSpeech, generator: DataGenerator, config_path: str, alphabet_path: str, test_dir: str):
    """Check the weights held by ``deepspeech`` after a one-epoch ``fit``.

    The current weights are saved to ``<test_dir>/weights_copy.hdf5``.
    That file is registered as the checkpoint callback's best-weights path
    before training. After ``fit`` the test asserts two things:

    * ``deepspeech.model`` has the same weights as a freshly constructed
      model loaded from that best-weights path.
    * ``deepspeech.compiled_model`` has the same weights it had before
      ``fit``.

    Files created by the test are removed even when an assertion fails.
    """
    # Save the current weights; they are registered as the "best" result below.
    weights_path = os.path.join(test_dir, 'weights_copy.hdf5')
    deepspeech.save(weights_path)
    try:
        distributed_weights = deepspeech.compiled_model.get_weights()
        # Assumption: callbacks[1] is the model-checkpoint callback. Setting
        # best_result to 0 presumably means no epoch counts as an improvement
        # (TODO confirm against the callback's comparison logic).
        model_checkpoint = deepspeech.callbacks[1]
        model_checkpoint.best_result = 0
        model_checkpoint.best_weights_path = weights_path
        history = deepspeech.fit(train_generator=generator, dev_generator=generator, epochs=1, shuffle=False)
        assert isinstance(history, History)

        # After fit, the model must hold the weights stored at the best-weights path.
        deepspeech_weights = deepspeech.model.get_weights()
        new_deepspeech = DeepSpeech.construct(config_path, alphabet_path)
        new_deepspeech.load(model_checkpoint.best_weights_path)
        new_deepspeech_weights = new_deepspeech.model.get_weights()
        assert is_same(deepspeech_weights, new_deepspeech_weights)

        # The compiled (distributed) model must end up with its pre-fit weights.
        new_distributed_weights = deepspeech.compiled_model.get_weights()
        assert is_same(distributed_weights, new_distributed_weights)
    finally:
        # NOTE(review): the checkpoint directory is hard-coded relative to the
        # CWD, unlike weights_path which honours test_dir. Confirm where the
        # checkpoint callback writes. ignore_errors keeps a missing directory
        # from masking an earlier assertion failure.
        shutil.rmtree('tests/checkpoints', ignore_errors=True)
        # Remove the file actually created above. The old hard-coded
        # 'tests/weights_copy.hdf5' only matched when test_dir == 'tests'.
        os.remove(weights_path)
def __run_program(configuration_line):
    """Train and save one DeepSpeech experiment, logging its stdout to a file.

    ``configuration_line`` is a ``|``-separated string. The first field is
    the configuration file path. Any remaining fields are passed, as a list,
    to ``__update_parameters``. Standard output produced while building,
    training and saving the model is written to ``program.out`` inside
    ``configuration.exp_dir``.
    """
    fields = configuration_line.split('|')
    # str.split always yields at least one element, so fields[0] is safe.
    config_path, overrides = fields[0], fields[1:]

    configuration = Configuration(config_path)
    __update_parameters(configuration, overrides)
    # Presumably ensures configuration.exp_dir exists before the log file is opened.
    __create_experiment_dir(configuration)

    log_path = os.path.join(configuration.exp_dir, 'program.out')
    with open(log_path, 'w') as log_file, redirect_stdout(log_file):
        model = DeepSpeech(configuration)
        model.train()
        model.save()
import argparse
import os
from source.deepspeech import DeepSpeech
from source.configuration import Configuration

# Switch the working directory to the folder containing this script. This
# presumably makes relative paths resolve against it. NOTE(review): it
# happens at import time too, not only when run as a script.
abspath = os.path.abspath(__file__)
ROOT_DIR = os.path.dirname(abspath)
os.chdir(ROOT_DIR)


def main(argv=None):
    """Parse command-line arguments, then train and save a DeepSpeech model.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # The value is passed straight to Configuration below. Require it so a
    # missing flag gives a usage error instead of passing None through.
    parser.add_argument('--configuration', required=True,
                        help='Experiment configuration')
    args = parser.parse_args(argv)
    # Read configuration file
    config = Configuration(file_path=args.configuration)
    # Set up DeepSpeech object
    ds = DeepSpeech(config)
    # Model optimization
    ds.train()
    # Save whole deepspeech model
    ds.save()


if __name__ == "__main__":
    main()