def __init__(self, metadata_file_path, hparams):
        # self.audiopaths_and_text = load_filepaths_and_text(metadata_file_path)

        self._metadata = _read_meta_yyh(metadata_file_path)

        self.mel_dir = hparams.mel_dir

        if hparams.use_phone:
            from text.phones import Phones
            self.phone_class = Phones(hparams.phone_set_file)
            self.text_to_sequence = self.phone_class.text_to_sequence
        else:
            from text import text_to_sequence
            self.text_to_sequence = text_to_sequence

        self.text_cleaners = hparams.text_cleaners
        self.is_multi_speakers = hparams.is_multi_speakers
        self.is_multi_styles = hparams.is_multi_styles

        # random.seed(1234)
        random.shuffle(self._metadata)
        self.utt_list = []
        self.utt_data = {}
        self.utt_mels = []
        for i, item in enumerate(self._metadata):
            # Metadata lines are pipe-separated; split each line once up front.
            fields = item.strip().split('|')
            # The last field carries the mel frame count for this utterance.
            self.utt_list.append({"mel_frame": int(fields[-1])})
            self.utt_data[str(i)] = {
                # Token ids for the input text (phones or characters).
                "text":
                torch.IntTensor(
                    self.text_to_sequence(fields[5], self.text_cleaners)),
                # Mel-spectrogram, rescaled via x * 8.0 - 4.0.
                "mel":
                torch.from_numpy(
                    np.load(self.mel_dir + 'out_' + fields[0] + '.npy'))
                * 8.0 - 4.0,
                # Per-token durations from saved encoder-decoder attentions.
                "dur":
                DurationCalculator._calculate_duration(
                    torch.from_numpy(
                        np.load(self.mel_dir + 'encdec_' + fields[0] +
                                '.npy'))),
                "speaker": int(fields[4]) if self.is_multi_speakers else None,
                "style": int(fields[2]) if self.is_multi_styles else None,
            }
            # self.utt_list.append({"mel_frame": int(np.load(self.mel_dir + 'out_' + fields[0] + '.npy').shape[0])})
            if hparams.is_refine_style:
                self.utt_mels.append(
                    (torch.from_numpy(
                        np.load(self.mel_dir + 'out_' + fields[0] + '.npy'))
                     * 8.0 - 4.0).unsqueeze(0))
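
For orientation, here is a minimal sketch of how this initializer could be driven. The class name (MelTextDataset) and the hparams container are assumptions for illustration; the attribute names mirror the ones read above, and each metadata line is expected to be pipe-separated with the utterance id first and the mel frame count last.

from types import SimpleNamespace

# Hypothetical hparams; attribute names mirror those read in __init__ above.
hparams = SimpleNamespace(
    mel_dir='mels/',                 # holds out_<id>.npy and encdec_<id>.npy
    use_phone=False,                 # use the character-level text_to_sequence
    phone_set_file=None,
    text_cleaners=['english_cleaners'],
    is_multi_speakers=True,
    is_multi_styles=False,
    is_refine_style=False,
)
# dataset = MelTextDataset('metadata.txt', hparams)  # class name is assumed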
Example #2
    def __init__(
        self,
        train_config: Optional[Union[Path, str]],
        model_file: Optional[Union[Path, str]] = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_teacher_forcing: bool = False,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
        speed_control_alpha: float = 1.0,
        vocoder_conf: Optional[dict] = None,
        dtype: str = "float32",
        device: str = "cpu",
    ):
        assert check_argument_types()

        model, train_args = TTSTask.build_model_from_file(
            train_config, model_file, device)
        model.to(dtype=getattr(torch, dtype)).eval()
        self.device = device
        self.dtype = dtype
        self.train_args = train_args
        self.model = model
        self.tts = model.tts
        self.normalize = model.normalize
        self.feats_extract = model.feats_extract
        self.duration_calculator = DurationCalculator()
        self.preprocess_fn = TTSTask.build_preprocess_fn(train_args, False)
        self.use_teacher_forcing = use_teacher_forcing

        logging.info(f"Normalization:\n{self.normalize}")
        logging.info(f"TTS:\n{self.tts}")

        decode_config = {}
        if isinstance(self.tts, (Tacotron2, Transformer)):
            decode_config.update({
                "threshold": threshold,
                "maxlenratio": maxlenratio,
                "minlenratio": minlenratio,
            })
        if isinstance(self.tts, Tacotron2):
            decode_config.update({
                "use_att_constraint": use_att_constraint,
                "forward_window": forward_window,
                "backward_window": backward_window,
            })
        if isinstance(self.tts, (FastSpeech, FastSpeech2)):
            decode_config.update({"alpha": speed_control_alpha})
        decode_config.update({"use_teacher_forcing": use_teacher_forcing})

        self.decode_config = decode_config

        if vocoder_conf is None:
            vocoder_conf = {}
        if self.feats_extract is not None:
            vocoder_conf.update(self.feats_extract.get_parameters())
        if ("n_fft" in vocoder_conf and "n_shift" in vocoder_conf
                and "fs" in vocoder_conf):
            self.spc2wav = Spectrogram2Waveform(**vocoder_conf)
            logging.info(f"Vocoder: {self.spc2wav}")
        else:
            self.spc2wav = None
            logging.info(
                "Vocoder is not used because vocoder_conf is not sufficient")
Example #3
    def get_dur(self, att_ws):
        # Load saved attention weights from a .npy file and convert them
        # to per-token durations.
        return DurationCalculator._calculate_duration(
            torch.from_numpy(np.load(att_ws)))
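
For intuition, this duration extraction amounts to a hard alignment over the attention matrix: each output frame is assigned to its most-attended input token, and a token's duration is the number of frames assigned to it. A self-contained sketch with made-up shapes:

import torch

# Fake attention weights: 120 output frames attending over 30 input tokens.
att_w = torch.rand(120, 30)
# Hard-assign each frame to its most-attended token, then count assignments.
durations = torch.stack(
    [att_w.argmax(-1).eq(i).sum() for i in range(att_w.size(1))])
assert int(durations.sum()) == att_w.size(0)  # every frame assigned exactly once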
Example #4
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_teacher_forcing: bool,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    speed_control_alpha: float,
    allow_variable_data_keys: bool,
    vocoder_conf: dict,
):
    """Perform TTS model decoding."""
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build model
    model, train_args = TTSTask.build_model_from_file(train_config, model_file, device)
    model.to(dtype=getattr(torch, dtype)).eval()
    tts = model.tts
    normalize = model.normalize
    feats_extract = model.feats_extract
    duration_calculator = DurationCalculator()
    logging.info(f"Normalization:\n{normalize}")
    logging.info(f"TTS:\n{tts}")

    # 3. Build decoding config
    decode_config = {}
    if isinstance(tts, (Tacotron2, Transformer)):
        decode_config.update(
            {
                "threshold": threshold,
                "maxlenratio": maxlenratio,
                "minlenratio": minlenratio,
                "use_teacher_forcing": use_teacher_forcing,
            }
        )
    if isinstance(tts, Tacotron2):
        decode_config.update(
            {
                "use_att_constraint": use_att_constraint,
                "forward_window": forward_window,
                "backward_window": backward_window,
            }
        )
    if isinstance(tts, FastSpeech):
        decode_config.update({"alpha": speed_control_alpha})
    use_speech = check_use_speech_in_inference(tts, decode_config)
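    # Reference speech is consumed only by some configurations (e.g. teacher
    # forcing), so the data entries named "speech" are kept or dropped below.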

    # 4. Build data-iterator
    if not use_speech:
        data_path_and_name_and_type = list(
            filter(lambda x: x[1] != "speech", data_path_and_name_and_type)
        )
    loader = TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(train_args, False),
        collate_fn=TTSTask.build_collate_fn(train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 5. Build converter from spectrogram to waveform
    if feats_extract is not None:
        vocoder_conf.update(feats_extract.get_parameters())
    if "n_fft" in vocoder_conf and "n_shift" in vocoder_conf and "fs" in vocoder_conf:
        spc2wav = Spectrogram2Waveform(**vocoder_conf)
        logging.info(f"Vocoder: {spc2wav}")
    else:
        spc2wav = None
        logging.info("Vocoder is not used because vocoder_conf is not sufficient")

    # 6. Start for-loop
    output_dir = Path(output_dir)
    (output_dir / "norm").mkdir(parents=True, exist_ok=True)
    (output_dir / "denorm").mkdir(parents=True, exist_ok=True)
    (output_dir / "speech_shape").mkdir(parents=True, exist_ok=True)
    (output_dir / "wav").mkdir(parents=True, exist_ok=True)
    (output_dir / "att_ws").mkdir(parents=True, exist_ok=True)
    (output_dir / "probs").mkdir(parents=True, exist_ok=True)
    (output_dir / "durations").mkdir(parents=True, exist_ok=True)
    (output_dir / "focus_rates").mkdir(parents=True, exist_ok=True)

    # Lazy load to avoid the backend error
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    with NpyScpWriter(
        output_dir / "norm", output_dir / "norm/feats.scp",
    ) as norm_writer, NpyScpWriter(
        output_dir / "denorm", output_dir / "denorm/feats.scp"
    ) as denorm_writer, open(
        output_dir / "speech_shape/speech_shape", "w"
    ) as shape_writer, open(
        output_dir / "durations/durations", "w"
    ) as duration_writer, open(
        output_dir / "focus_rates/focus_rates", "w"
    ) as focus_rate_writer:
        for idx, (keys, batch) in enumerate(loader, 1):
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = to_device(batch, device)

            # Extract features if speech is needed
            if use_speech:
                _speech = (v for k, v in batch.items() if k.startswith("speech"))
                if feats_extract is not None:
                    _speech = feats_extract(*_speech)
                speech, speech_lengths = normalize(*_speech)
                batch.update(speech=speech, speech_lengths=speech_lengths)

            key = keys[0]
            # Change to single sequence and remove *_length
            # because inference() requires 1-seq, not mini-batch.
            _data = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            start_time = time.perf_counter()
            outs, probs, att_ws = tts.inference(**_data, **decode_config)
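            # outs: output features (T_feats, odim); probs: stop-token
            # probabilities (None for non-autoregressive models); att_ws:
            # attention weights (None if the model does not produce them).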
            insize = next(iter(_data.values())).size(0) + 1
            logging.info(
                "inference speed = {:.1f} frames / sec.".format(
                    int(outs.size(0)) / (time.perf_counter() - start_time)
                )
            )
            logging.info(f"{key} (size:{insize}->{outs.size(0)})")
            if outs.size(0) == insize * maxlenratio:
                logging.warning(f"output length reaches maximum length ({key}).")
            norm_writer[key] = outs.cpu().numpy()
            shape_writer.write(f"{key} " + ",".join(map(str, outs.shape)) + "\n")

            # NOTE: normalize.inverse is an in-place operation
            outs_denorm = normalize.inverse(outs[None])[0][0]
            denorm_writer[key] = outs_denorm.cpu().numpy()

            if att_ws is not None:
                # Save durations and focus rates
                duration, focus_rate = duration_calculator(att_ws)
                duration_writer.write(
                    f"{key} " + " ".join(map(str, duration.cpu().numpy())) + "\n"
                )
                focus_rate_writer.write(f"{key} {float(focus_rate):.5f}\n")

                # Plot attention weight
                att_ws = att_ws.cpu().numpy()

                if att_ws.ndim == 2:
                    att_ws = att_ws[None][None]
                elif att_ws.ndim != 4:
                    raise RuntimeError(f"Must be 2 or 4 dimension: {att_ws.ndim}")

                w, h = plt.figaspect(att_ws.shape[0] / att_ws.shape[1])
                fig = plt.Figure(
                    figsize=(
                        w * 1.3 * min(att_ws.shape[0], 2.5),
                        h * 1.3 * min(att_ws.shape[1], 2.5),
                    )
                )
                fig.suptitle(f"{key}")
                # squeeze=False keeps a 2-D array of axes for any grid shape.
                axes = fig.subplots(
                    att_ws.shape[0], att_ws.shape[1], squeeze=False)
                for ax, att_w in zip(axes, att_ws):
                    for ax_, att_w_ in zip(ax, att_w):
                        ax_.imshow(att_w_.astype(np.float32), aspect="auto")
                        ax_.set_xlabel("Input")
                        ax_.set_ylabel("Output")
                        ax_.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax_.yaxis.set_major_locator(MaxNLocator(integer=True))

                fig.set_tight_layout({"rect": [0, 0.03, 1, 0.95]})
                fig.savefig(output_dir / f"att_ws/{key}.png")
                fig.clf()

            if probs is not None:
                # Plot stop token prediction
                probs = probs.cpu().numpy()

                fig = plt.Figure()
                ax = fig.add_subplot(1, 1, 1)
                ax.plot(probs)
                ax.set_title(f"{key}")
                ax.set_xlabel("Output")
                ax.set_ylabel("Stop probability")
                ax.set_ylim(0, 1)
                ax.grid(which="both")

                fig.set_tight_layout(True)
                fig.savefig(output_dir / f"probs/{key}.png")
                fig.clf()

            # TODO(kamo): Write scp
            if spc2wav is not None:
                wav = spc2wav(outs_denorm.cpu().numpy())
                sf.write(f"{output_dir}/wav/{key}.wav", wav, spc2wav.fs, "PCM_16")

    # Remove duration-related outputs if no attention weights were produced.
    if att_ws is None:
        shutil.rmtree(output_dir / "att_ws")
        shutil.rmtree(output_dir / "probs")
        shutil.rmtree(output_dir / "durations")
        shutil.rmtree(output_dir / "focus_rates")
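
For reference, a direct call to this function might look as follows; every path and value is a placeholder, and in practice these arguments are typically populated from the command line by the surrounding script.

# Hypothetical invocation with placeholder paths and default-like settings.
inference(
    output_dir='exp/decode',
    batch_size=1,
    dtype='float32',
    ngpu=0,
    seed=0,
    num_workers=1,
    log_level='INFO',
    data_path_and_name_and_type=[('dump/test/text', 'text', 'text')],
    key_file=None,
    train_config='exp/tts_train/config.yaml',
    model_file='exp/tts_train/valid.loss.best.pth',
    threshold=0.5,
    minlenratio=0.0,
    maxlenratio=10.0,
    use_teacher_forcing=False,
    use_att_constraint=False,
    backward_window=1,
    forward_window=3,
    speed_control_alpha=1.0,
    allow_variable_data_keys=False,
    vocoder_conf={},
)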