def run_deep_speech(_):
  """Run deep speech training and eval loop.

  Runs `train_epochs // epochs_between_evals` training cycles. Each cycle:
  - reshuffles the training entries;
  - trains the estimator, capped at `flags_obj.max_train_steps`;
  - unless `flags_obj.skip_eval` is set, evaluates WER/CER on the eval set
    and logs the results.
  Training stops early once `model_helpers.past_stop_threshold` reports
  that the WER threshold has been met.

  NOTE(review): this module defines `run_deep_speech` more than once; only
  the last definition is bound at import time -- confirm which is intended.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      application runner -- confirm against the caller).
  """
  tf.set_random_seed(flags_obj.seed)
  # Data preprocessing
  tf.logging.info("Data preprocessing...")
  train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
  eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)

  # Number of label classes. Label string is "[a-z]' -"
  num_classes = len(train_speech_dataset.speech_labels)

  # Use distribution strategy for multi-gpu training
  num_gpus = flags_core.get_num_gpus(flags_obj)
  distribution_strategy = distribution_utils.get_distribution_strategy(
      num_gpus=num_gpus)
  run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          "num_classes": num_classes,
      })

  # Benchmark logging
  run_params = {
      "batch_size": flags_obj.batch_size,
      "train_epochs": flags_obj.train_epochs,
      "rnn_hidden_size": flags_obj.rnn_hidden_size,
      "rnn_hidden_layers": flags_obj.rnn_hidden_layers,
      "rnn_type": flags_obj.rnn_type,
      "is_bidirectional": flags_obj.is_bidirectional,
      "use_bias": flags_obj.use_bias
  }

  dataset_name = "LibriSpeech"
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info("deep_speech", dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  # Presumably the global batch size divided across the replicas in use --
  # confirm against distribution_utils.
  per_replica_batch_size = distribution_utils.per_replica_batch_size(
      flags_obj.batch_size, num_gpus)

  def input_fn_train():
    return dataset.input_fn(per_replica_batch_size, train_speech_dataset)

  def input_fn_eval():
    return dataset.input_fn(per_replica_batch_size, eval_speech_dataset)

  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals)
  if total_training_cycle < 1:
    # Floor division yields 0 when train_epochs < epochs_between_evals; make
    # the resulting no-op run visible instead of returning silently.
    tf.logging.warning(
        "No training cycles to run: train_epochs=%s, "
        "epochs_between_evals=%s",
        flags_obj.train_epochs, flags_obj.epochs_between_evals)

  for cycle_index in range(total_training_cycle):
    tf.logging.info("Starting a training cycle: %d/%d",
                    cycle_index + 1, total_training_cycle)

    # Perform batch_wise dataset shuffling. The entries are reassigned on the
    # same dataset object that input_fn_train closes over, so the next train
    # call presumably sees the reshuffled order.
    train_speech_dataset.entries = dataset.batch_wise_dataset_shuffle(
        train_speech_dataset.entries, cycle_index, flags_obj.sortagrad,
        flags_obj.batch_size)

    estimator.train(input_fn=input_fn_train, hooks=train_hooks,
                    max_steps=flags_obj.max_train_steps)

    # skip_eval skips evaluation only; the remaining training cycles still
    # run. (Once max_train_steps is reached, later train calls return
    # without training.) The WER early-stop check needs eval results, so it
    # is skipped as well.
    if flags_obj.skip_eval:
      continue

    # Evaluation
    tf.logging.info("Starting to evaluate...")
    eval_results = evaluate_model(
        estimator, eval_speech_dataset.speech_labels,
        eval_speech_dataset.entries, input_fn_eval)

    # Log the WER and CER results.
    benchmark_logger.log_evaluation_result(eval_results)
    tf.logging.info("Iteration %d: WER = %.2f, CER = %.2f",
                    cycle_index + 1, eval_results[_WER_KEY],
                    eval_results[_CER_KEY])

    # Stop early once the WER threshold is met.
    if model_helpers.past_stop_threshold(flags_obj.wer_threshold,
                                         eval_results[_WER_KEY]):
      break
def run_deep_speech(_):
  """Run deep speech training and eval loop.

  Runs `train_epochs // epochs_between_evals` training cycles. Each cycle:
  - reshuffles the training entries;
  - trains the estimator;
  - evaluates WER/CER on the eval set and logs the results.
  Training stops early once `model_helpers.past_stop_threshold` reports
  that the WER threshold has been met.

  NOTE(review): this module defines `run_deep_speech` more than once; only
  the last definition is bound at import time -- confirm which is intended.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      application runner -- confirm against the caller).
  """
  tf.set_random_seed(flags_obj.seed)
  # Data preprocessing
  tf.logging.info("Data preprocessing...")
  train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
  eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)

  # Number of label classes. Label string is "[a-z]' -"
  num_classes = len(train_speech_dataset.speech_labels)

  # Use distribution strategy for multi-gpu training
  num_gpus = flags_core.get_num_gpus(flags_obj)
  distribution_strategy = distribution_utils.get_distribution_strategy(num_gpus)
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          "num_classes": num_classes,
      }
  )

  # Benchmark logging
  run_params = {
      "batch_size": flags_obj.batch_size,
      "train_epochs": flags_obj.train_epochs,
      "rnn_hidden_size": flags_obj.rnn_hidden_size,
      "rnn_hidden_layers": flags_obj.rnn_hidden_layers,
      "rnn_type": flags_obj.rnn_type,
      "is_bidirectional": flags_obj.is_bidirectional,
      "use_bias": flags_obj.use_bias
  }

  dataset_name = "LibriSpeech"
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info("deep_speech", dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  # Presumably the global batch size divided across the devices in use --
  # confirm against distribution_utils.
  per_device_batch_size = distribution_utils.per_device_batch_size(
      flags_obj.batch_size, num_gpus)

  def input_fn_train():
    return dataset.input_fn(
        per_device_batch_size, train_speech_dataset)

  def input_fn_eval():
    return dataset.input_fn(
        per_device_batch_size, eval_speech_dataset)

  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals)
  if total_training_cycle < 1:
    # Floor division yields 0 when train_epochs < epochs_between_evals; make
    # the resulting no-op run visible instead of returning silently.
    tf.logging.warning(
        "No training cycles to run: train_epochs=%s, "
        "epochs_between_evals=%s",
        flags_obj.train_epochs, flags_obj.epochs_between_evals)

  for cycle_index in range(total_training_cycle):
    tf.logging.info("Starting a training cycle: %d/%d",
                    cycle_index + 1, total_training_cycle)

    # Perform batch_wise dataset shuffling. The entries are reassigned on the
    # same dataset object that input_fn_train closes over, so the next train
    # call presumably sees the reshuffled order.
    train_speech_dataset.entries = dataset.batch_wise_dataset_shuffle(
        train_speech_dataset.entries, cycle_index, flags_obj.sortagrad,
        flags_obj.batch_size)

    estimator.train(input_fn=input_fn_train, hooks=train_hooks)

    # Evaluation
    tf.logging.info("Starting to evaluate...")
    eval_results = evaluate_model(
        estimator, eval_speech_dataset.speech_labels,
        eval_speech_dataset.entries, input_fn_eval)

    # Log the WER and CER results.
    benchmark_logger.log_evaluation_result(eval_results)
    tf.logging.info("Iteration %d: WER = %.2f, CER = %.2f",
                    cycle_index + 1, eval_results[_WER_KEY],
                    eval_results[_CER_KEY])

    # Stop early once the WER threshold is met.
    if model_helpers.past_stop_threshold(
        flags_obj.wer_threshold, eval_results[_WER_KEY]):
      break