# Example #1
def train_and_evaluate(hparams, model_dir, train_source_files, train_target_files, eval_source_files,
                       eval_target_files, use_multi_gpu):
    """Train a tacotron model with periodic evaluation via tf.estimator.

    Args:
        hparams: hyper-parameter object; dataset, batching and checkpointing
            settings are read from it.
        model_dir: directory where checkpoints and summaries are written.
        train_source_files: source tfrecord files for training.
        train_target_files: target tfrecord files aligned with the sources.
        eval_source_files: source tfrecord files for evaluation.
        eval_target_files: target tfrecord files aligned with the eval sources.
        use_multi_gpu: when True, train under a MirroredStrategy distribution.
    """

    interleave_parallelism = get_parallelism(hparams.interleave_cycle_length_cpu_factor,
                                             hparams.interleave_cycle_length_min,
                                             hparams.interleave_cycle_length_max)

    tf.logging.info("Interleave parallelism is %d.", interleave_parallelism)

    def train_input_fn():
        # Shuffle the (source, target) file pairs together so they stay aligned.
        file_pairs = list(zip(train_source_files, train_target_files))
        shuffle(file_pairs)
        source_files = [pair[0] for pair in file_pairs]
        target_files = [pair[1] for pair in file_pairs]

        dataset = create_from_tfrecord_files(source_files, target_files, hparams,
                                             cycle_length=interleave_parallelism,
                                             buffer_output_elements=hparams.interleave_buffer_output_elements,
                                             prefetch_input_elements=hparams.interleave_prefetch_input_elements)

        zipped = dataset.prepare_and_zip()
        if hparams.use_cache:
            zipped = zipped.cache(hparams.cache_file_name)
        filtered = zipped.filter_by_max_output_length()
        # NOTE(review): "suffle_buffer_size" mirrors the hparam's own spelling;
        # renaming it here would break the lookup.
        batched = filtered.repeat(count=None) \
            .shuffle(hparams.suffle_buffer_size) \
            .group_by_batch() \
            .prefetch(hparams.prefetch_buffer_size)
        return batched.dataset

    def eval_input_fn():
        # Shuffle the eval file pairs the same way as the training files.
        file_pairs = list(zip(eval_source_files, eval_target_files))
        shuffle(file_pairs)
        source = tf.data.TFRecordDataset([pair[0] for pair in file_pairs])
        target = tf.data.TFRecordDataset([pair[1] for pair in file_pairs])

        prepared = dataset_factory(source, target, hparams)
        zipped = prepared.prepare_and_zip()
        evaluation = zipped.filter_by_max_output_length().repeat().group_by_batch(batch_size=1)
        return evaluation.dataset

    distribution = tf.contrib.distribute.MirroredStrategy() if use_multi_gpu else None

    run_config = tf.estimator.RunConfig(save_summary_steps=hparams.save_summary_steps,
                                        save_checkpoints_steps=hparams.save_checkpoints_steps,
                                        keep_checkpoint_max=hparams.keep_checkpoint_max,
                                        log_step_count_steps=hparams.log_step_count_steps,
                                        train_distribute=distribution)
    estimator = tacotron_model_factory(hparams, model_dir, run_config)

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                      steps=hparams.num_evaluation_steps,
                                      throttle_secs=hparams.eval_throttle_secs,
                                      start_delay_secs=hparams.eval_start_delay_secs)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
def predict(hparams, model_dir, checkpoint_path, output_dir, test_source_files,
            test_target_files):
    """Run inference and dump predicted codes (".mfbsp") plus tfrecord results.

    NOTE(review): a later `def predict` in this file redefines the same name,
    so this earlier definition is shadowed (dead code) unless one of the two
    is renamed.
    """
    def predict_input_fn():
        # Build a batch-size-1 dataset; merge_target_to_source feeds the
        # ground-truth target alongside the source input.
        source = tf.data.TFRecordDataset(list(test_source_files))
        target = tf.data.TFRecordDataset(list(test_target_files))
        dataset = dataset_factory(source, target, hparams)
        batched = dataset.prepare_and_zip().group_by_batch(
            batch_size=1).merge_target_to_source()
        return batched.dataset

    estimator = tacotron_model_factory(hparams, model_dir, None)

    # Wrap each raw prediction dict in a PredictedCodes record; alignment2..6
    # are optional (p.get), the rest are required keys.
    predictions = map(
        lambda p: PredictedCodes(p["id"], p["key"], p["codes"], p[
            "ground_truth_codes"], p["alignment"], p.get("alignment2"),
                                 p.get("alignment3"), p.get("alignment4"),
                                 p.get("alignment5"), p.get("alignment6"),
                                 p["source"], p["text"], p.get("accent_type")),
        estimator.predict(predict_input_fn, checkpoint_path=checkpoint_path))

    count = 0
    for v in predictions:
        count += 1
        key = v.key.decode('utf-8')
        # Raw codes are written as little-endian float32 next to the tfrecord.
        filename = f"{key}.mfbsp"
        filepath = os.path.join(output_dir, filename)
        codes = v.codes
        #        assert codes.shape[1] == 512
        codes.tofile(filepath, format='<f4')
        text = v.text.decode("utf-8")
        print(key, codes.shape[0], v.ground_truth_codes.shape[0], text)
        #        plot_filename = f"{key}.png"
        #        plot_filepath = os.path.join(output_dir, plot_filename)
        # Keep only the alignments that the model actually produced.
        alignments = list(
            filter(lambda x: x is not None, [
                v.alignment, v.alignment2, v.alignment3, v.alignment4,
                v.alignment5, v.alignment6
            ]))

        #        plot_predictions(alignments, v.ground_truth_codes, v.predicted_codes, v.predicted_mel_postnet,
        #                         text, v.key, plot_filepath)
        prediction_filename = f"{key}.tfrecord"
        prediction_filepath = os.path.join(output_dir, prediction_filename)
        write_prediction_result(v.id, key, alignments, codes,
                                v.ground_truth_codes, text, v.source,
                                v.accent_type, prediction_filepath)
        # NOTE(review): debugging artifact — hard-stops the whole process with
        # sys.exit() after 10 predictions; confirm before relying on this path.
        if count == 10:
            sys.exit()
def predict(hparams, model_dir, checkpoint_path, output_dir, test_source_files,
            test_target_files):
    """Run inference and write predicted mels, alignment plots and tfrecords.

    For every test utterance this writes three artifacts into output_dir:
    the raw mel (little-endian float32), an alignment plot PNG, and a
    tfrecord with the full prediction result.
    """
    def predict_input_fn():
        # Batch size 1; merge_target_to_source feeds ground truth with input.
        source_records = tf.data.TFRecordDataset(list(test_source_files))
        target_records = tf.data.TFRecordDataset(list(test_target_files))
        prepared = dataset_factory(source_records, target_records, hparams)
        batched = prepared.prepare_and_zip().group_by_batch(
            batch_size=1).merge_target_to_source()
        return batched.dataset

    estimator = tacotron_model_factory(hparams, model_dir, None)

    def _to_predicted_mel(p):
        # alignment2..6, mel_postnet and accent_type are optional outputs.
        return PredictedMel(
            p["id"], p["key"], p["mel"], p.get("mel_postnet"),
            p["mel"].shape[1], p["mel"].shape[0], p["ground_truth_mel"],
            p["alignment"], p.get("alignment2"), p.get("alignment3"),
            p.get("alignment4"), p.get("alignment5"), p.get("alignment6"),
            p["source"], p["text"], p.get("accent_type"))

    raw_predictions = estimator.predict(predict_input_fn,
                                        checkpoint_path=checkpoint_path)

    for prediction in (_to_predicted_mel(p) for p in raw_predictions):
        key = prediction.key.decode('utf-8')
        mel_filename = f"{key}.{hparams.predicted_mel_extension}"
        mel_filepath = os.path.join(output_dir, mel_filename)
        # Prefer the postnet-refined mel when the model was trained with it.
        if hparams.use_postnet_v2:
            mel = prediction.predicted_mel_postnet
        else:
            mel = prediction.predicted_mel
        assert mel.shape[1] == hparams.num_mels
        mel.tofile(mel_filepath, format='<f4')
        text = prediction.text.decode("utf-8")
        plot_filename = f"{key}.png"
        plot_filepath = os.path.join(output_dir, plot_filename)
        # Keep only the alignments that the model actually produced.
        alignments = [a for a in (prediction.alignment, prediction.alignment2,
                                  prediction.alignment3, prediction.alignment4,
                                  prediction.alignment5, prediction.alignment6)
                      if a is not None]

        plot_predictions(alignments, prediction.ground_truth_mel,
                         prediction.predicted_mel,
                         prediction.predicted_mel_postnet, text,
                         prediction.key, plot_filepath)
        prediction_filename = f"{key}.tfrecord"
        prediction_filepath = os.path.join(output_dir, prediction_filename)
        write_prediction_result(prediction.id, key, alignments, mel,
                                prediction.ground_truth_mel, text,
                                prediction.source, prediction.accent_type,
                                prediction_filepath)