Exemplo n.º 1
0
    def test_waveglow_training(self):
        """Smoke-test WaveGlow training on the sample manifest.

        Wires a reduced WaveGlow (6 flows, 4 WaveNet layers, 64 mel channels)
        behind an audio data layer and a mel-spectrogram preprocessor, then
        trains it with SGD for 10 epochs while a callback logs the loss.
        Nothing is asserted about the loss value: the test passes as long as
        the training run completes without raising.
        """
        # Modules are constructed in the same order as the graph is wired.
        audio_source = nemo_tts.AudioDataLayer(
            manifest_filepath=self.manifest_filepath, n_segments=4000, batch_size=4,
        )
        mel_frontend = nemo_asr.AudioToMelSpectrogramPreprocessor(
            window_size=None,
            window_stride=None,
            n_window_size=512,
            n_window_stride=128,
            normalize=None,
            preemph=None,
            dither=0,
            mag_power=1.0,
            pad_value=-11.52,
        )
        vocoder = nemo_tts.WaveGlowNM(
            n_mel_channels=64,
            n_flows=6,
            n_group=4,
            n_early_every=4,
            n_early_size=2,
            n_wn_layers=4,
            n_wn_channels=256,
            wn_kernel_size=3,
        )
        vocoder_loss = nemo_tts.WaveGlowLoss()

        # Build the DAG: audio -> mel spectrogram -> WaveGlow -> loss.
        audio_signal, audio_length = audio_source()
        mel_spec, _ = mel_frontend(input_signal=audio_signal, length=audio_length)

        flow_z, log_s_terms, log_det_w_terms = vocoder(mel_spectrogram=mel_spec, audio=audio_signal)
        train_loss = vocoder_loss(z=flow_z, log_s_list=log_s_terms, log_det_W_list=log_det_w_terms)

        def _log_loss(evaluated):
            # `evaluated` holds the values of `tensors` below; index 0 is the loss.
            logging.info(f'Train Loss: {str(evaluated[0].item())}')

        loss_logger = nemo.core.SimpleLossLoggerCallback(tensors=[train_loss], print_func=_log_loss)

        # The factory is created after the modules, matching the original setup order.
        factory = nemo.core.NeuralModuleFactory(
            backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False,
        )
        trainer = factory.get_trainer()
        trainer.train(
            [train_loss],
            callbacks=[loss_logger],
            optimizer="sgd",
            optimization_params={"num_epochs": 10, "lr": 0.0003},
        )
Exemplo n.º 2
0
def create_NMs(waveglow_params, sample_rate=22050):
    """Build the neural modules used to train WaveGlow.

    Args:
        waveglow_params: Mapping with an ``"AudioToMelSpectrogramPreprocessor"``
            entry and a ``"WaveGlowNM"`` entry. Each entry is unpacked as keyword
            arguments into the corresponding module constructor.
        sample_rate: Sample rate handed to both ``WaveGlowNM`` and
            ``WaveGlowLoss``. Defaults to 22050, the value previously hard-coded
            here. It is NOT forwarded to the preprocessor, which is configured
            solely from ``waveglow_params``.

    Returns:
        Tuple ``(data_preprocessor, waveglow, waveglow_loss)``.
    """
    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        **waveglow_params["AudioToMelSpectrogramPreprocessor"])
    # A single parameter feeds both the model and its loss so the two cannot
    # drift apart. The model keeps the original positional call form.
    waveglow = nemo_tts.WaveGlowNM(sample_rate, **waveglow_params["WaveGlowNM"])
    waveglow_loss = nemo_tts.WaveGlowLoss(sample_rate=sample_rate)

    logging.info('================================')
    logging.info(f"Total number of parameters: {waveglow.num_weights}")
    logging.info('================================')
    return (data_preprocessor, waveglow, waveglow_loss)
Exemplo n.º 3
0
    def test_waveglow_training(self):
        """Integration test that instantiates a smaller WaveGlow model and tests training with the sample ASR data.

        Training is run for 3 forward and backward steps, and the test asserts that the loss after 3 steps is
        smaller than the loss at the first step.
        """
        data_layer = nemo_tts.AudioDataLayer(
            manifest_filepath=self.manifest_filepath, n_segments=4000, batch_size=4, sample_rate=16000
        )
        preprocessing = nemo_asr.AudioToMelSpectrogramPreprocessor(
            window_size=None,
            window_stride=None,
            n_window_size=512,
            n_window_stride=128,
            normalize=None,
            preemph=None,
            dither=0,
            mag_power=1.0,
            pad_value=-11.52,
        )
        waveglow = nemo_tts.WaveGlowNM(
            n_mel_channels=64,
            n_flows=6,
            n_group=4,
            n_early_every=4,
            n_early_size=2,
            n_wn_layers=4,
            n_wn_channels=256,
            wn_kernel_size=3,
            sample_rate=16000,
        )
        waveglow_loss = nemo_tts.WaveGlowLoss(sample_rate=16000)

        # DAG
        audio, audio_len = data_layer()
        spec_target, _ = preprocessing(input_signal=audio, length=audio_len)

        z, log_s_list, log_det_W_list = waveglow(mel_spectrogram=spec_target, audio=audio)
        loss_t = waveglow_loss(z=z, log_s_list=log_s_list, log_det_W_list=log_det_W_list)

        # print_and_log_loss receives this list via partial; presumably it appends one loss per
        # callback invocation (step_freq=1) - verify against the helper's definition.
        loss_list = []
        callback = SimpleLossLoggerCallback(
            tensors=[loss_t], print_func=partial(self.print_and_log_loss, loss_log_list=loss_list), step_freq=1
        )
        # PtActions drives the training loop (it is the trainer, not the optimizer itself).
        trainer = PtActions()
        trainer.train(
            [loss_t], callbacks=[callback], optimizer="sgd", optimization_params={"max_steps": 3, "lr": 0.01}
        )

        # Comparing first vs. last loss is only meaningful with at least two recorded values; fail with a
        # clear message instead of an IndexError or a self-comparison.
        assert len(loss_list) >= 2, f"expected at least 2 logged losses, got {len(loss_list)}"
        # Assert that training loss went down
        assert loss_list[-1] < loss_list[0], f"loss did not decrease: first={loss_list[0]}, last={loss_list[-1]}"