def train_and_evaluate(hparams, model_dir, train_source_files, train_target_files,
                       eval_source_files, eval_target_files, use_multi_gpu):
    """Train the Tacotron estimator while periodically evaluating it.

    Builds separate tf.data input pipelines for training and evaluation,
    optionally enables a MirroredStrategy for multi-GPU training, and hands
    everything to ``tf.estimator.train_and_evaluate``.

    Args:
        hparams: hyper-parameter object read for pipeline, checkpoint and
            evaluation settings.
        model_dir: directory for checkpoints and summaries.
        train_source_files: source TFRecord paths for training.
        train_target_files: target TFRecord paths for training (paired by
            position with ``train_source_files``).
        eval_source_files: source TFRecord paths for evaluation.
        eval_target_files: target TFRecord paths for evaluation.
        use_multi_gpu: when truthy, train under MirroredStrategy.
    """
    interleave_parallelism = get_parallelism(
        hparams.interleave_cycle_length_cpu_factor,
        hparams.interleave_cycle_length_min,
        hparams.interleave_cycle_length_max)
    tf.logging.info("Interleave parallelism is %d.", interleave_parallelism)

    def train_input_fn():
        # Shuffle file order while keeping each (source, target) pair aligned.
        paired = list(zip(train_source_files, train_target_files))
        shuffle(paired)
        sources = [s for s, _ in paired]
        targets = [t for _, t in paired]
        dataset = create_from_tfrecord_files(
            sources, targets, hparams,
            cycle_length=interleave_parallelism,
            buffer_output_elements=hparams.interleave_buffer_output_elements,
            prefetch_input_elements=hparams.interleave_prefetch_input_elements)
        zipped = dataset.prepare_and_zip()
        if hparams.use_cache:
            zipped = zipped.cache(hparams.cache_file_name)
        # NOTE(review): "suffle_buffer_size" looks like a typo for
        # "shuffle_buffer_size", but it must match the hparams key exactly —
        # confirm against the hparams definition before renaming.
        batched = (zipped.filter_by_max_output_length()
                   .repeat(count=None)
                   .shuffle(hparams.suffle_buffer_size)
                   .group_by_batch()
                   .prefetch(hparams.prefetch_buffer_size))
        return batched.dataset

    def eval_input_fn():
        paired = list(zip(eval_source_files, eval_target_files))
        shuffle(paired)
        source = tf.data.TFRecordDataset([s for s, _ in paired])
        target = tf.data.TFRecordDataset([t for _, t in paired])
        dataset = dataset_factory(source, target, hparams)
        zipped = dataset.prepare_and_zip()
        prepared = (zipped.filter_by_max_output_length()
                    .repeat()
                    .group_by_batch(batch_size=1))
        return prepared.dataset

    if use_multi_gpu:
        distribution = tf.contrib.distribute.MirroredStrategy()
    else:
        distribution = None
    run_config = tf.estimator.RunConfig(
        save_summary_steps=hparams.save_summary_steps,
        save_checkpoints_steps=hparams.save_checkpoints_steps,
        keep_checkpoint_max=hparams.keep_checkpoint_max,
        log_step_count_steps=hparams.log_step_count_steps,
        train_distribute=distribution)
    estimator = tacotron_model_factory(hparams, model_dir, run_config)
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=hparams.num_evaluation_steps,
        throttle_secs=hparams.eval_throttle_secs,
        start_delay_secs=hparams.eval_start_delay_secs)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
def predict(hparams, model_dir, checkpoint_path, output_dir, test_source_files, test_target_files):
    """Run inference and write predicted codes to ``output_dir``.

    For each prediction, writes the raw codes to ``<key>.mfbsp`` and a full
    result record to ``<key>.tfrecord`` via ``write_prediction_result``.
    Only the first 10 predictions are processed.

    NOTE(review): another function named ``predict`` is defined later in this
    file and shadows this one at module load; only the later definition is
    reachable by name. Confirm which variant is meant to be live and rename
    the other.

    Args:
        hparams: hyper-parameter object for the dataset and model factories.
        model_dir: directory containing the trained model.
        checkpoint_path: explicit checkpoint to restore, or None for latest.
        output_dir: directory receiving the prediction artifacts.
        test_source_files: source TFRecord paths.
        test_target_files: target TFRecord paths.
    """

    def predict_input_fn():
        source = tf.data.TFRecordDataset(list(test_source_files))
        target = tf.data.TFRecordDataset(list(test_target_files))
        dataset = dataset_factory(source, target, hparams)
        batched = (dataset.prepare_and_zip()
                   .group_by_batch(batch_size=1)
                   .merge_target_to_source())
        return batched.dataset

    estimator = tacotron_model_factory(hparams, model_dir, None)
    predictions = map(
        lambda p: PredictedCodes(
            p["id"], p["key"], p["codes"], p["ground_truth_codes"],
            p["alignment"], p.get("alignment2"), p.get("alignment3"),
            p.get("alignment4"), p.get("alignment5"), p.get("alignment6"),
            p["source"], p["text"], p.get("accent_type")),
        estimator.predict(predict_input_fn, checkpoint_path=checkpoint_path))

    for count, v in enumerate(predictions, start=1):
        key = v.key.decode('utf-8')
        filepath = os.path.join(output_dir, f"{key}.mfbsp")
        codes = v.codes
        codes.tofile(filepath, format='<f4')
        text = v.text.decode("utf-8")
        print(key, codes.shape[0], v.ground_truth_codes.shape[0], text)
        alignments = [a for a in (v.alignment, v.alignment2, v.alignment3,
                                  v.alignment4, v.alignment5, v.alignment6)
                      if a is not None]
        prediction_filepath = os.path.join(output_dir, f"{key}.tfrecord")
        write_prediction_result(v.id, key, alignments, codes,
                                v.ground_truth_codes, text, v.source,
                                v.accent_type, prediction_filepath)
        # The original called sys.exit() here, killing the whole process after
        # the 10th item; `break` keeps the truncation but returns normally so
        # callers can continue.
        if count == 10:
            break
def predict(hparams, model_dir, checkpoint_path, output_dir, test_source_files, test_target_files):
    """Run inference and write predicted mel spectrograms to ``output_dir``.

    For each prediction this writes the mel (postnet output when
    ``hparams.use_postnet_v2`` is set) as raw little-endian float32 to
    ``<key>.<predicted_mel_extension>``, an alignment plot to ``<key>.png``,
    and a full result record to ``<key>.tfrecord``.

    NOTE(review): this definition shadows an earlier ``predict`` in this file
    — confirm which variant is meant to be live and rename the other.

    Args:
        hparams: hyper-parameter object for the dataset and model factories.
        model_dir: directory containing the trained model.
        checkpoint_path: explicit checkpoint to restore, or None for latest.
        output_dir: directory receiving the prediction artifacts.
        test_source_files: source TFRecord paths.
        test_target_files: target TFRecord paths.
    """

    def predict_input_fn():
        source_records = tf.data.TFRecordDataset(list(test_source_files))
        target_records = tf.data.TFRecordDataset(list(test_target_files))
        dataset = dataset_factory(source_records, target_records, hparams)
        batched = (dataset.prepare_and_zip()
                   .group_by_batch(batch_size=1)
                   .merge_target_to_source())
        return batched.dataset

    def _to_predicted_mel(p):
        # Wrap the raw estimator output dict in the project's result record.
        return PredictedMel(
            p["id"], p["key"], p["mel"], p.get("mel_postnet"),
            p["mel"].shape[1], p["mel"].shape[0],
            p["ground_truth_mel"], p["alignment"],
            p.get("alignment2"), p.get("alignment3"), p.get("alignment4"),
            p.get("alignment5"), p.get("alignment6"),
            p["source"], p["text"], p.get("accent_type"))

    estimator = tacotron_model_factory(hparams, model_dir, None)
    raw_predictions = estimator.predict(predict_input_fn,
                                        checkpoint_path=checkpoint_path)
    for v in map(_to_predicted_mel, raw_predictions):
        key = v.key.decode('utf-8')
        mel_filepath = os.path.join(output_dir,
                                    f"{key}.{hparams.predicted_mel_extension}")
        mel = v.predicted_mel_postnet if hparams.use_postnet_v2 else v.predicted_mel
        assert mel.shape[1] == hparams.num_mels
        mel.tofile(mel_filepath, format='<f4')
        text = v.text.decode("utf-8")
        plot_filepath = os.path.join(output_dir, f"{key}.png")
        alignments = [a for a in (v.alignment, v.alignment2, v.alignment3,
                                  v.alignment4, v.alignment5, v.alignment6)
                      if a is not None]
        plot_predictions(alignments, v.ground_truth_mel, v.predicted_mel,
                         v.predicted_mel_postnet, text, v.key, plot_filepath)
        prediction_filepath = os.path.join(output_dir, f"{key}.tfrecord")
        write_prediction_result(v.id, key, alignments, mel, v.ground_truth_mel,
                                text, v.source, v.accent_type,
                                prediction_filepath)