Example #1
File: writer.py Project: zwjgit/VocGAN
class MyWriter(SummaryWriter):
    def __init__(self, hp, logdir):
        super(MyWriter, self).__init__(logdir)
        self.sample_rate = hp.audio.sampling_rate
        self.stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                                 hop_length=hp.audio.hop_length,
                                 win_length=hp.audio.win_length,
                                 n_mel_channels=hp.audio.n_mel_channels,
                                 sampling_rate=hp.audio.sampling_rate,
                                 mel_fmin=hp.audio.mel_fmin,
                                 mel_fmax=hp.audio.mel_fmax)
        self.is_first = True

    def log_training(self, g_loss, d_loss, adv_loss, step):
        self.add_scalar('train.g_loss', g_loss, step)
        self.add_scalar('train.d_loss', d_loss, step)
        self.add_scalar('train.adv_loss', adv_loss, step)

    def log_validation(self, g_loss, d_loss, adv_loss, generator,
                       discriminator, target, prediction, step):
        self.add_scalar('validation.g_loss', g_loss, step)
        self.add_scalar('validation.d_loss', d_loss, step)
        self.add_scalar('validation.adv_loss', adv_loss, step)
        self.add_audio('raw_audio_predicted', prediction, step,
                       self.sample_rate)
        self.add_image('waveform_predicted',
                       plot_waveform_to_numpy(prediction), step)
        wav = torch.from_numpy(prediction).unsqueeze(0)
        mel = self.stft.mel_spectrogram(wav)  # mel [1, num_mel, T]
        self.add_image('melspectrogram_prediction',
                       plot_spectrogram_to_numpy(
                           mel.squeeze(0).data.cpu().numpy()),
                       step,
                       dataformats='HWC')
        self.log_histogram(generator, step)
        self.log_histogram(discriminator, step)

        if self.is_first:
            self.add_audio('raw_audio_target', target, step, self.sample_rate)
            self.add_image('waveform_target', plot_waveform_to_numpy(target),
                           step)
            wav = torch.from_numpy(target).unsqueeze(0)
            mel = self.stft.mel_spectrogram(wav)  # mel [1, num_mel, T]
            self.add_image('melspectrogram_target',
                           plot_spectrogram_to_numpy(
                               mel.squeeze(0).data.cpu().numpy()),
                           step,
                           dataformats='HWC')
            self.is_first = False

    def log_evaluation(self, generated, step, name):
        self.add_audio(f'evaluation/{name}', generated, step, self.sample_rate)

    def log_histogram(self, model, step):
        for tag, value in model.named_parameters():
            self.add_histogram(tag.replace('.', '/'),
                               value.cpu().detach().numpy(), step)
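
A hypothetical usage sketch for the writer above (hp must expose the hp.audio.* fields referenced in __init__; the logdir and loss values are placeholders):

writer = MyWriter(hp, logdir='logs/vocgan')
writer.log_training(g_loss=1.23, d_loss=0.45, adv_loss=0.07, step=1000)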
Example #2
def main(hp, args):
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)

    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)
    mel_path = hp.data.mel_path
    os.makedirs(mel_path, exist_ok=True)

    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        assert sr == hp.audio.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.audio.sampling_rate, sr, wavpath)

        if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
            wav = np.pad(wav, (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)),
                         mode='constant', constant_values=0.0)

        wav = torch.from_numpy(wav).unsqueeze(0)
        mel = stft.mel_spectrogram(wav)  # mel [1, num_mel, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        id = os.path.basename(wavpath).split(".")[0]
        np.save('{}/{}.npy'.format(mel_path, id),
                mel.numpy(),
                allow_pickle=False)
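
The saved arrays can be loaded straight back with NumPy; a one-line sanity check (the utterance id is a placeholder):

mel = np.load(os.path.join(mel_path, 'some_utterance_id.npy'))  # [num_mel, T]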
Example #3
def main(hp, args):
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)

    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.flac'), recursive=True)

    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_flac_np(wavpath)
        assert sr == hp.audio.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.audio.sampling_rate, sr, wavpath)
        
        if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
            wav = np.pad(wav, (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)),
                         mode='constant', constant_values=0.0)

        wav = torch.from_numpy(wav).unsqueeze(0)
        mel = stft.mel_spectrogram(wav)

        melpath = wavpath.replace('.flac', '.mel')
        torch.save(mel, melpath)
Example #4
def main(hp, args):
    stft = TacotronSTFT(filter_length=hp.filter_length,
                        hop_length=hp.hop_length,
                        win_length=hp.win_length,
                        n_mel_channels=hp.n_mel_channels,
                        sampling_rate=hp.sampling_rate,
                        mel_fmin=hp.mel_fmin,
                        mel_fmax=hp.mel_fmax)

    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True)

    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        assert sr == hp.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.sampling_rate, sr, wavpath)
        
        if len(wav) < hp.segment_length + hp.pad_short:
            wav = np.pad(wav, (0, hp.segment_length + hp.pad_short - len(wav)),
                         mode='constant', constant_values=0.0)

        # f0, _ = pw.dio(wav.astype(np.float64), hp.sampling_rate, frame_period=hp.hop_length/hp.sampling_rate*1000)

        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, energy = stft.mel_spectrogram(wav)
        # f0 = f0[:sum(duration)]

        melpath = os.path.join(args.out_path, os.path.basename(wavpath.replace('.wav', '.npy')))
        # amppath = os.path.join(args.out_path, os.path.basename(wavpath.replace('.wav', '_amp.npy')))
        np.save(melpath, mel.squeeze(0).transpose(0,1).numpy())
Example #5
def extract_mel(wav_dir, fs, speaker_id_pos, filter_length, hop_length,
                mel_fmin, mel_fmax):
    """ Extract mel-spectrogram from audio waveform """
    mel_extractor = TacotronSTFT(filter_length=filter_length,
                                 hop_length=hop_length,
                                 n_mel_channels=80,
                                 sampling_rate=fs,
                                 mel_fmin=mel_fmin,
                                 mel_fmax=mel_fmax).cuda()
    audio_dataset = AudioDataset(wav_dir=wav_dir,
                                 n_speaker=None,
                                 speaker_id_pos=speaker_id_pos)
    audio_collate_fn = AudioCollateFn()
    audio_loader = DataLoader(audio_dataset,
                              batch_size=32,
                              shuffle=False,
                              collate_fn=audio_collate_fn,
                              num_workers=0)

    for batch in progressbar(audio_loader):
        audio, lengths, audio_path = batch
        audio = audio.cuda(0)
        mel = mel_extractor.mel_spectrogram(audio).detach().cpu().numpy()
        for i, input_wav in enumerate(audio_path):
            mel_fname = input_wav.split("/")[-1].replace(".wav", "")
            output_mel_dir = input_wav.replace("wav16", "mel16k")
            output_mel_dir = "/".join(output_mel_dir.split("/")[:-1])
            if not exists(output_mel_dir):
                os.makedirs(output_mel_dir)
            _mel = mel[i, :, :int(math.ceil(lengths[i] / hop_length))]
            np.save(join(output_mel_dir, mel_fname), _mel)
Example #6
def main(args):
    stft = TacotronSTFT(filter_length=hp.n_fft,
                        hop_length=hp.hop_length,
                        win_length=hp.win_length,
                        n_mel_channels=hp.n_mels,
                        sampling_rate=hp.sample_rate,
                        mel_fmin=hp.fmin,
                        mel_fmax=hp.fmax)

    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)
    mel_path = os.path.join(hp.data_dir, 'mels')
    energy_path = os.path.join(hp.data_dir, 'energy')
    pitch_path = os.path.join(hp.data_dir, 'pitch')
    os.makedirs(mel_path, exist_ok=True)
    os.makedirs(energy_path, exist_ok=True)
    os.makedirs(pitch_path, exist_ok=True)
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        p = pitch(wav)  # [T, ] T = Number of frames
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, 80, T]  mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        mag = mag.squeeze(0)  # [num_mag, T]
        e = torch.norm(mag, dim=0)  # [T, ]
        p = p[:mel.shape[1]]
        id = os.path.basename(wavpath).split(".")[0]
        np.save('{}/{}.npy'.format(mel_path, id),
                mel.numpy(),
                allow_pickle=False)
        np.save('{}/{}.npy'.format(energy_path, id),
                e.numpy(),
                allow_pickle=False)
        np.save('{}/{}.npy'.format(pitch_path, id), p, allow_pickle=False)
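
The per-frame energy above is simply the L2 norm of each magnitude-spectrogram column; a tiny standalone check of that identity:

import torch

mag = torch.rand(513, 100)  # [num_mag, T]
e = torch.norm(mag, dim=0)  # [T]
assert torch.allclose(e, mag.pow(2).sum(dim=0).sqrt())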
Example #7
def preprocess(data_path, hp, file):
    stft = TacotronSTFT(
        filter_length=hp.audio.n_fft,
        hop_length=hp.audio.hop_length,
        win_length=hp.audio.win_length,
        n_mel_channels=hp.audio.n_mels,
        sampling_rate=hp.audio.sample_rate,
        mel_fmin=hp.audio.fmin,
        mel_fmax=hp.audio.fmax,
    )

    mel_path = os.path.join(hp.data.data_dir, "mels")
    energy_path = os.path.join(hp.data.data_dir, "energy")
    pitch_path = os.path.join(hp.data.data_dir, "pitch")
    avg_mel_phon = os.path.join(hp.data.data_dir, "avg_mel_ph")

    os.makedirs(mel_path, exist_ok=True)
    os.makedirs(energy_path, exist_ok=True)
    os.makedirs(pitch_path, exist_ok=True)
    os.makedirs(avg_mel_phon, exist_ok=True)
    print("Sample Rate : ", hp.audio.sample_rate)

    with open("{}".format(file), encoding="utf-8") as f:
        _metadata = [line.strip().split("|") for line in f]
    for metadata in tqdm.tqdm(_metadata, desc="preprocess wav to mel"):
        wavpath = os.path.join(data_path, metadata[4])
        sr, wav = read_wav_np(wavpath, hp.audio.sample_rate)

        dur = str_to_int_list(metadata[2])
        dur = torch.from_numpy(np.array(dur))

        p = pitch(wav, hp)  # [T, ] T = Number of frames
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, 80, T]  mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        mag = mag.squeeze(0)  # [num_mag, T]
        e = torch.norm(mag, dim=0)  # [T, ]
        p = p[:mel.shape[1]]

        avg_mel_ph = _average_mel_by_duration(mel, dur)  # [L, num_mel]
        assert (avg_mel_ph.shape[0] == dur.shape[-1])

        id = os.path.basename(wavpath).split(".")[0]
        np.save("{}/{}.npy".format(mel_path, id),
                mel.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(energy_path, id),
                e.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(pitch_path, id), p, allow_pickle=False)
        np.save("{}/{}.npy".format(avg_mel_phon, id),
                avg_mel_ph.numpy(),
                allow_pickle=False)
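
The helper _average_mel_by_duration is not shown in this snippet; a minimal sketch consistent with the assert above (one averaged mel vector per phoneme, so shape[0] equals the number of durations) might look like:

import torch

def average_mel_by_duration(mel, dur):
    # mel: [num_mel, T]; dur: [L] integer frame counts with dur.sum() <= T
    out = torch.zeros(dur.shape[-1], mel.shape[0])  # [L, num_mel]
    start = 0
    for i, d in enumerate(dur.tolist()):
        if d > 0:
            out[i] = mel[:, start:start + d].mean(dim=1)
        start += d
    return out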
Example #8
class Mel2Samp(torch.utils.data.Dataset):
    """
    This is the main class that calculates the spectrogram and returns the
    spectrogram, audio pair.
    """
    def __init__(self, training_files, segment_length, filter_length,
                 hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
        self.audio_files = files_to_list(training_files)
        random.seed(1234)
        random.shuffle(self.audio_files)
        self.stft = TacotronSTFT(filter_length=filter_length,
                                 hop_length=hop_length,
                                 win_length=win_length,
                                 sampling_rate=sampling_rate,
                                 mel_fmin=mel_fmin, mel_fmax=mel_fmax)
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate

    def get_mel(self, audio):
        audio_norm = audio / MAX_WAV_VALUE
        audio_norm = audio_norm.unsqueeze(0)
        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)
        return melspec

    def __getitem__(self, index):
        # Read audio
        filename = self.audio_files[index]
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))

        # Take segment
        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = random.randint(0, max_audio_start)
            audio = audio[audio_start:audio_start+self.segment_length]
        else:
            audio = torch.nn.functional.pad(audio, (0, self.segment_length - audio.size(0)), 'constant').data

        mel = self.get_mel(audio)
        audio = audio / MAX_WAV_VALUE

        return (mel, audio)

    def __len__(self):
        return len(self.audio_files)
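
A hypothetical usage sketch (the filelist path and hyperparameters are placeholders; files_to_list, load_wav_to_torch, and MAX_WAV_VALUE come from the surrounding project):

dataset = Mel2Samp('train_files.txt', segment_length=16000,
                   filter_length=1024, hop_length=256, win_length=1024,
                   sampling_rate=22050, mel_fmin=0.0, mel_fmax=8000.0)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
mel, audio = next(iter(loader))  # mel: [4, num_mel, frames], audio: [4, 16000]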
Example #9
def main(hp, args):
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)

    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)

    save_train_mel_path = 'melgan_train_mel_data'
    save_val_mel_path = 'melgan_val_mel_data'
    os.makedirs(save_train_mel_path, exist_ok=True)
    os.makedirs(save_val_mel_path, exist_ok=True)

    random.shuffle(wav_files)

    count = 0
    #for wavpath in wav_files:
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        assert sr == hp.audio.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.audio.sampling_rate, sr, wavpath)

        if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
            wav = np.pad(wav, (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)),
                         mode='constant', constant_values=0.0)

        wav = torch.from_numpy(wav).unsqueeze(0)
        mel = stft.mel_spectrogram(wav)

        wav_name = wavpath.split('/')[5]
        melpath = wavpath.replace('.wav', '.mel')
        mel_name = melpath.split('/')[5]

        if count < 300:
            final_mel_path = os.path.join(save_val_mel_path, mel_name)
            final_wav_path = os.path.join(save_val_mel_path, wav_name)
        else:
            final_mel_path = os.path.join(save_train_mel_path, mel_name)
            final_wav_path = os.path.join(save_train_mel_path, wav_name)
        torch.save(mel, final_mel_path)
        shutil.copy(wavpath, final_wav_path)
        count += 1
Example #10
def main(args, hp):
    stft = TacotronSTFT(
        filter_length=hp.audio.n_fft,
        hop_length=hp.audio.hop_length,
        win_length=hp.audio.win_length,
        n_mel_channels=hp.audio.n_mels,
        sampling_rate=hp.audio.sample_rate,
        mel_fmin=hp.audio.fmin,
        mel_fmax=hp.audio.fmax,
    )

    wav_files = glob.glob(os.path.join(args.data_path, "**", "*.wav"),
                          recursive=True)
    mel_path = os.path.join(hp.data.data_dir, "mels")
    energy_path = os.path.join(hp.data.data_dir, "energy")
    pitch_path = os.path.join(hp.data.data_dir, "pitch")
    os.makedirs(mel_path, exist_ok=True)
    os.makedirs(energy_path, exist_ok=True)
    os.makedirs(pitch_path, exist_ok=True)
    print("Sample Rate : ", hp.audio.sample_rate)
    for wavpath in tqdm.tqdm(wav_files, desc="preprocess wav to mel"):
        sr, wav = read_wav_np(wavpath, hp.audio.sample_rate)
        p = pitch(wav, hp)  # [T, ] T = Number of frames
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, 80, T]  mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        mag = mag.squeeze(0)  # [num_mag, T]
        e = torch.norm(mag, dim=0)  # [T, ]
        p = p[:mel.shape[1]]
        id = os.path.basename(wavpath).split(".")[0]
        np.save("{}/{}.npy".format(mel_path, id),
                mel.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(energy_path, id),
                e.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(pitch_path, id), p, allow_pickle=False)
Example #11
def main(hp, args):
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)

    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)

    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        assert sr == hp.audio.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.audio.sampling_rate, sr, wavpath)

        wav = torch.from_numpy(wav).unsqueeze(0)
        mel = stft.mel_spectrogram(wav)

        melpath = wavpath.replace('.wav', '.mel')
        torch.save(mel, melpath)
Example #12
def main(args, hp):
    stft = TacotronSTFT(filter_length=hp.filter_length,
                        hop_length=hp.hop_length,
                        win_length=hp.win_length,
                        n_mel_channels=hp.n_mel_channels,
                        sampling_rate=hp.sampling_rate,
                        mel_fmin=hp.mel_fmin,
                        mel_fmax=hp.mel_fmax)

    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)
    mel_path = hp.data_path
    os.makedirs(mel_path, exist_ok=True)
    print("Sample Rate : ", hp.sampling_rate)
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath, hp.sampling_rate)
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, 80, T]  mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        id = os.path.basename(wavpath).split(".")[0]
        np.save('{}/{}.npy'.format(mel_path, id),
                mel.numpy(),
                allow_pickle=False)
Example #13
def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
          hp_str):
    model_g = Generator(hp.audio.n_mel_channels).cuda()
    model_d = MultiScaleDiscriminator(hp.model.num_D, hp.model.ndf,
                                      hp.model.n_layers,
                                      hp.model.downsampling_factor,
                                      hp.model.disc_out).cuda()
    model_d_mpd = MPD().cuda()

    optim_g = torch.optim.Adam(model_g.parameters(),
                               lr=hp.train.adam.lr,
                               betas=(hp.train.adam.beta1,
                                      hp.train.adam.beta2))
    optim_d = torch.optim.Adam(itertools.chain(model_d.parameters(),
                                               model_d_mpd.parameters()),
                               lr=hp.train.adam.lr,
                               betas=(hp.train.adam.beta1,
                                      hp.train.adam.beta2))

    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)

    # githash = get_commit_hash()

    init_epoch = -1
    step = 0

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model_g.load_state_dict(checkpoint['model_g'])
        model_d.load_state_dict(checkpoint['model_d'])
        model_d_mpd.load_state_dict(checkpoint['model_d_mpd'])
        optim_g.load_state_dict(checkpoint['optim_g'])
        optim_d.load_state_dict(checkpoint['optim_d'])
        step = checkpoint['step']
        init_epoch = checkpoint['epoch']

        if hp_str != checkpoint['hp_str']:
            logger.warning(
                "New hparams is different from checkpoint. Will use new.")

        # if githash != checkpoint['githash']:
        #     logger.warning("Code might be different: git hash is different.")
        #     logger.warning("%s -> %s" % (checkpoint['githash'], githash))

    else:
        logger.info("Starting new training run.")

    # cudnn.benchmark speeds up training when the minibatch size stays consistent;
    # if it varies, training slows down dramatically.
    torch.backends.cudnn.benchmark = True

    try:
        model_g.train()
        model_d.train()
        stft_loss = MultiResolutionSTFTLoss()
        criterion = torch.nn.MSELoss().cuda()
        l1loss = torch.nn.L1Loss()

        for epoch in itertools.count(init_epoch + 1):
            if epoch % hp.log.validation_interval == 0:
                with torch.no_grad():
                    validate(hp, model_g, model_d, model_d_mpd, valloader,
                             stft_loss, l1loss, criterion, stft, writer, step)

            trainloader.dataset.shuffle_mapping()
            loader = tqdm.tqdm(trainloader, desc='Loading train data')
            avg_g_loss = []
            avg_d_loss = []
            avg_adv_loss = []
            for (melG, audioG), (melD, audioD) in loader:
                melG = melG.cuda()  # torch.Size([16, 80, 64])
                audioG = audioG.cuda()  # torch.Size([16, 1, 16000])
                melD = melD.cuda()  # torch.Size([16, 80, 64])
                audioD = audioD.cuda()  # torch.Size([16, 1, 16000])

                # generator
                optim_g.zero_grad()
                fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]
                # torch.Size([16, 1, 12800])

                loss_g = 0.0

                sc_loss, mag_loss = stft_loss(
                    fake_audio[:, :, :audioG.size(2)].squeeze(1),
                    audioG.squeeze(1))
                loss_g += sc_loss + mag_loss  # STFT Loss

                adv_loss = 0.0
                loss_mel = 0.0
                if step > hp.train.discriminator_train_start_steps:
                    disc_real = model_d(audioG)
                    disc_fake = model_d(fake_audio)
                    # for multi-scale discriminator

                    for feats_fake, score_fake in disc_fake:
                        # adv_loss += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
                        adv_loss += criterion(score_fake,
                                              torch.ones_like(score_fake))
                    adv_loss = adv_loss / len(disc_fake)  # len(disc_fake) = 3

                    # MPD adversarial loss
                    out1, out2, out3, out4, out5 = model_d_mpd(fake_audio)
                    adv_mpd_loss = (criterion(out1, torch.ones_like(out1)) +
                                    criterion(out2, torch.ones_like(out2)) +
                                    criterion(out3, torch.ones_like(out3)) +
                                    criterion(out4, torch.ones_like(out4)) +
                                    criterion(out5, torch.ones_like(out5)))
                    adv_mpd_loss = adv_mpd_loss / 5
                    adv_loss = adv_loss + adv_mpd_loss  # Adv Loss

                    # Mel Loss
                    mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1))
                    loss_mel += l1loss(melG[:, :, :mel_fake.size(2)],
                                       mel_fake.cuda())  # Mel L1 loss
                    loss_g += hp.model.lambda_mel * loss_mel

                    if hp.model.feat_loss:
                        for (feats_fake, score_fake), (feats_real, _) in zip(disc_fake, disc_real):
                            for feat_f, feat_r in zip(feats_fake, feats_real):
                                adv_loss += hp.model.feat_match * torch.mean(
                                    torch.abs(feat_f - feat_r))

                    loss_g += hp.model.lambda_adv * adv_loss

                loss_g.backward()
                optim_g.step()

                # discriminator
                loss_d_avg = 0.0
                if step > hp.train.discriminator_train_start_steps:
                    fake_audio = model_g(melD)[:, :, :hp.audio.segment_length]
                    fake_audio = fake_audio.detach()
                    loss_d_sum = 0.0
                    for _ in range(hp.train.rep_discriminator):
                        optim_d.zero_grad()
                        disc_fake = model_d(fake_audio)
                        disc_real = model_d(audioD)
                        loss_d = 0.0
                        loss_d_real = 0.0
                        loss_d_fake = 0.0
                        for (_, score_fake), (_, score_real) in zip(
                                disc_fake, disc_real):
                            loss_d_real += criterion(
                                score_real, torch.ones_like(score_real))
                            loss_d_fake += criterion(
                                score_fake, torch.zeros_like(score_fake))
                        loss_d_real = loss_d_real / len(disc_real)  # len(disc_real) = 3
                        loss_d_fake = loss_d_fake / len(disc_fake)  # len(disc_fake) = 3
                        loss_d += loss_d_real + loss_d_fake  # MSD loss

                        loss_d_sum += loss_d

                        # MPD adversarial loss
                        out1, out2, out3, out4, out5 = model_d_mpd(fake_audio)
                        out1_real, out2_real, out3_real, out4_real, out5_real = model_d_mpd(
                            audioD)
                        loss_mpd_fake = (criterion(out1, torch.zeros_like(out1)) +
                                         criterion(out2, torch.zeros_like(out2)) +
                                         criterion(out3, torch.zeros_like(out3)) +
                                         criterion(out4, torch.zeros_like(out4)) +
                                         criterion(out5, torch.zeros_like(out5)))
                        loss_mpd_real = (criterion(out1_real, torch.ones_like(out1_real)) +
                                         criterion(out2_real, torch.ones_like(out2_real)) +
                                         criterion(out3_real, torch.ones_like(out3_real)) +
                                         criterion(out4_real, torch.ones_like(out4_real)) +
                                         criterion(out5_real, torch.ones_like(out5_real)))
                        loss_mpd = (loss_mpd_fake + loss_mpd_real) / 5  # MPD loss
                        loss_d += loss_mpd
                        loss_d.backward()
                        optim_d.step()
                        loss_d_sum += loss_mpd

                    loss_d_avg = loss_d_sum / hp.train.rep_discriminator
                    loss_d_avg = loss_d_avg.item()

                step += 1
                # logging
                loss_g = loss_g.item()
                avg_g_loss.append(loss_g)
                avg_d_loss.append(loss_d_avg)
                avg_adv_loss.append(adv_loss)

                if any([
                        loss_g > 1e8,
                        math.isnan(loss_g), loss_d_avg > 1e8,
                        math.isnan(loss_d_avg)
                ]):
                    logger.error("loss_g %.01f loss_d_avg %.01f at step %d!" %
                                 (loss_g, loss_d_avg, step))
                    raise Exception("Loss exploded")

                if step % hp.log.summary_interval == 0:
                    writer.log_training(loss_g, loss_d_avg, adv_loss, loss_mel,
                                        step)
                    loader.set_description(
                        "Avg : g %.04f d %.04f ad %.04f| step %d" %
                        (sum(avg_g_loss) / len(avg_g_loss),
                         sum(avg_d_loss) / len(avg_d_loss),
                         sum(avg_adv_loss) / len(avg_adv_loss), step))
            if epoch % hp.log.save_interval == 0:
                save_path = os.path.join(pt_dir,
                                         '%s_%04d.pt' % (args.name, epoch))
                torch.save(
                    {
                        'model_g': model_g.state_dict(),
                        'model_d': model_d.state_dict(),
                        'model_d_mpd': model_d_mpd.state_dict(),
                        'optim_g': optim_g.state_dict(),
                        'optim_d': optim_d.state_dict(),
                        'step': step,
                        'epoch': epoch,
                        'hp_str': hp_str
                    }, save_path)
                logger.info("Saved checkpoint to: %s" % save_path)

    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
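
The adversarial terms above follow the least-squares GAN pattern: the discriminator regresses real scores toward 1 and fake scores toward 0, while the generator pushes fake scores toward 1. A minimal standalone illustration with dummy score tensors (not the models above):

import torch

criterion = torch.nn.MSELoss()
score_real = torch.randn(16, 1, 64)  # stand-ins for discriminator outputs
score_fake = torch.randn(16, 1, 64)

# discriminator objective: real -> 1, fake -> 0
loss_d = criterion(score_real, torch.ones_like(score_real)) + \
         criterion(score_fake, torch.zeros_like(score_fake))

# generator objective: fake -> 1
loss_g = criterion(score_fake, torch.ones_like(score_fake))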
Example #14
def train(args, chkpt_dir, chkpt_path, writer, logger, hp, hp_str, seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # https://github.com/pytorch/pytorch/issues/6351
    torch.backends.cudnn.benchmark = True
    if args.amp:
        print("Automatic Mixed Precision Training")
        scaler = amp.GradScaler()

    criterion = WaveFlowLoss(hp.model.sigma)
    model = WaveFlow(hp.model.flows, hp.model.n_group, hp.audio.sampling_rate,
                     hp.audio.win_length, hp.audio.n_mel_channels, hp).cuda()

    num_params(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam.lr)

    # Load checkpoint if one exists

    githash = get_commit_hash()

    init_epoch = -1
    step = 0

    if chkpt_path is not None:
        #logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optim'])
        step = checkpoint['step']
        init_epoch = checkpoint['epoch']

        if hp_str != checkpoint['hp_str']:
            logger.warning(
                "New hparams is different from checkpoint. Will use new.")

        if githash != checkpoint['githash']:
            logger.warning("Code might be different: git hash is different.")
            logger.warning("%s -> %s" % (checkpoint['githash'], githash))

    else:
        logger.info("Starting new training run.")

    train_loader = create_dataloader(hp, True)
    valid_loader = create_dataloader(hp, False)
    # Get shared output_directory ready
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)

    model.train()
    # ================ MAIN TRAINING LOOP ===================
    for epoch in itertools.count(init_epoch + 1):
        if epoch % hp.log.validation_interval == 0:
            with torch.no_grad():
                pass

        loader = tqdm.tqdm(train_loader, desc='Loading train data')
        loss_list = []
        for (mel, audio) in loader:
            model.zero_grad()
            #mel, audio = batch
            mel = torch.autograd.Variable(mel.cuda())  # [B, num_mel, num_frames]
            audio = torch.autograd.Variable(audio.cuda())  # [B, T]
            if args.amp:
                with amp.autocast():
                    z, logdet, _ = model(audio, mel)  # [B, T]
                    loss = criterion(z, logdet)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                z, logdet, _ = model(audio, mel)  # [B, T]
                loss = criterion(z, logdet)
                loss.backward()
                optimizer.step()

            loss_list.append(loss.item())

            loader.set_description("Avg Loss : loss %.04f | step %d" %
                                   (sum(loss_list) / len(loss_list), step))
            if step % hp.log.summary_interval == 0:
                writer.add_scalar("train.loss", loss.item(), step)
                # pass the global step so these scalars align with train.loss
                writer.add_scalar('train.log_determinant', logdet.mean().item(), step)
                writer.add_scalar('train.z_mean', z.mean().item(), step)
                writer.add_scalar('train.z_std', z.std().item(), step)

            if step % hp.log.validation_interval == 0:
                for (mel_, audio_) in valid_loader:
                    model.eval()
                    x, logdet_ = model.infer(mel_.cuda(),
                                             hp.model.sigma)  # x -> [T]

                    torch.clamp(x, -1, 1, out=x)
                    writer.add_scalar('valid.log_determinant', logdet_.mean().item(), step)
                    writer.add_audio('actual_audio',
                                     audio_.squeeze(0).cpu().detach().numpy(),
                                     step,
                                     sample_rate=hp.audio.sampling_rate)
                    writer.add_audio('reconstruct_audio',
                                     x.cpu().detach().numpy(),
                                     step,
                                     sample_rate=hp.audio.sampling_rate)
                    mel_spec = mel_[0].cpu().detach()
                    mel_spec -= mel_spec.min()
                    mel_spec /= mel_spec.max()
                    writer.add_image('actual_mel-spectrum',
                                     plot_spectrogram_to_numpy(
                                         mel_spec.numpy()),
                                     step,
                                     dataformats='HWC')

                    mel_gen, _ = stft.mel_spectrogram(x.unsqueeze(0))
                    mel_g_spec = mel_gen[0].cpu().detach()
                    mel_g_spec -= mel_g_spec.min()
                    mel_g_spec /= mel_g_spec.max()
                    writer.add_image('gen_mel-spectrum',
                                     plot_spectrogram_to_numpy(
                                         mel_g_spec.numpy()),
                                     step,
                                     dataformats='HWC')
                    model.train()
                    break

            step += 1

        if epoch % hp.log.save_interval == 0:
            save_path = os.path.join(
                chkpt_dir, '%s_%s_%04d.pt' % (args.name, githash, epoch))
            torch.save(
                {
                    'model': model.state_dict(),
                    'optim': optimizer.state_dict(),
                    'step': step,
                    'epoch': epoch,
                    'hp_str': hp_str,
                    'githash': githash,
                }, save_path)
            logger.info("Saved checkpoint to: %s" % save_path)