Esempio n. 1
0
def main(meta_dir: str, save_dir: str,
         save_prefix: str, pretrained_path: str = '',
         model_name: str = 'tacotron2_base', batch_size: int = 32, num_workers: int = 16,
         lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0,
         max_step: int = 100000, valid_max_step: int = 50, save_interval: int = 1000, log_interval: int = 50,
         grad_clip: float = 0.0, grad_norm: float = 30.0, gamma: float = 0.1, milestones: Tuple[int, ...] = None,
         sr: int = 22050):
    """Train a Tacotron2 model end to end.

    Builds the model, wraps it in DataParallel when several GPUs are visible,
    sets up Adam (optionally with a MultiStepLR schedule when `milestones`
    is given), loads train/valid loaders from `meta_dir`, and runs
    Tacotron2Trainer until `max_step`.

    :param meta_dir: directory holding dataset metadata
    :param save_dir: directory where checkpoints/logs are written
    :param save_prefix: prefix for saved artifacts
    :param pretrained_path: optional checkpoint to resume from
    :param milestones: LR-decay steps; None disables scheduling
    :param sr: audio sample rate passed through to the trainer
    """
    # create model
    model = build_model(model_name).cuda()

    # multi-gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # create optimizers
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
    if milestones:
        # milestones may arrive as strings from a CLI; coerce to int
        milestones = [int(x) for x in list(milestones)]
        scheduler = MultiStepLR(optimizer, milestones, gamma=gamma)
    else:
        scheduler = None

    train_loader, valid_loader = get_datasets(
        meta_dir, batch_size=batch_size, num_workers=num_workers
    )

    # valid_max_step is capped so one validation pass never exceeds the loader
    Tacotron2Trainer(
        model, optimizer, train_loader, valid_loader,
        max_step=max_step, valid_max_step=min(valid_max_step, len(valid_loader)), save_interval=save_interval,
        log_interval=log_interval,
        save_dir=save_dir, save_prefix=save_prefix, grad_clip=grad_clip, grad_norm=grad_norm,
        pretrained_path=pretrained_path, scheduler=scheduler, sr=sr
    ).run()
Esempio n. 2
0
def __load_model(model_name: str, pretrained_path: str) -> torch.nn.Module:
    """Build `model_name` on GPU, restore weights from `pretrained_path`,
    and return the network in eval mode."""
    print('Load model ...')
    network = build_model(model_name).cuda()
    state = torch.load(pretrained_path)['model']
    network.load_state_dict(get_loadable_checkpoint(state))
    network.eval()
    return network
Esempio n. 3
0
def __load_model(model_name: str, pretrained_path: str) -> torch.nn.Module:
    """Build `model_name` on CPU, restore weights from `pretrained_path`,
    and return the network in eval mode."""
    print('Load model ...')
    network = build_model(model_name)
    # force CPU mapping so GPU-saved checkpoints load on CPU-only hosts
    checkpoint = torch.load(pretrained_path, map_location=torch.device('cpu'))
    network.load_state_dict(get_loadable_checkpoint(checkpoint['model']))
    network.eval()
    return network
Esempio n. 4
0
    def __init__(self, model_name: str = 'hifi_gan_v1'):
        """Set up a mel-spectrogram front-end and a pretrained HiFi-GAN decoder."""
        valid_names = ['hifi_gan_v1', 'hifi_gan_v2']
        assert model_name in valid_names, \
            'Model name {} is not valid! choose in {}'.format(model_name, str(valid_names))

        # mel-spectrogram front-end built from the project audio defaults
        self.encoder = MelSpectrogram(**vars(AudioParameters()))

        # vocoder back-end, loaded with its pretrained generator weights
        self.decoder = build_model(model_name)
        state = torch.load(CHECKPOINTS[model_name])
        self.decoder.load_state_dict(state['generator'])
Esempio n. 5
0
    def __init__(self, model_name: str = 'hifi_gan_v1', chk_path: str = '', device='cpu'):
        """Set up a mel front-end and a HiFi-GAN decoder on `device`.

        When `chk_path` is non-empty it overrides the default pretrained
        checkpoint registered in CHECKPOINTS.
        """
        valid_names = ['hifi_gan_v1', 'hifi_gan_v2', 'hifi_gan_v3']
        assert model_name in valid_names, \
            'Model name {} is not valid! choose in {}'.format(
                model_name, str(valid_names))

        # mel-spectrogram front-end on the requested device
        self.encoder = MelSpectrogram(**vars(AudioParameters())).to(device)

        # vocoder generator; prefer an explicit checkpoint path if given
        self.decoder = build_model(model_name).to(device)
        checkpoint_source = chk_path if chk_path else CHECKPOINTS[model_name]
        state = torch.load(checkpoint_source)
        self.decoder.load_state_dict(state['generator'])
        # strip weight norm for inference
        self.decoder.remove_weight_norm()
Esempio n. 6
0
def main(meta_dir: str, save_dir: str,
         save_prefix: str, pretrained_path: str = '',
         model_name: str = 'refine_unet_base', batch_size: int = 128, num_workers: int = 16, fix_len: float = 2.,
         lr: float = 5e-4, betas: Tuple[float, float] = (0.5, 0.9), weight_decay: float = 0.0,
         max_step: int = 200000, valid_max_step: int = 30, save_interval: int = 1000, log_interval: int = 50,
         grad_clip: float = 0.0, grad_norm: float = 30.0,
         is_augment: bool = True, milestones: Tuple[int, ...] = None, gamma: float = 0.1,
         case_name: str = 'voice_bank', mix_loss: bool = False):
    """Train a wave-to-wave refinement model on the dataset case `case_name`.

    Builds the model (DataParallel when multiple GPUs are visible), creates
    Adam with an optional MultiStepLR schedule, resolves the dataset via
    `handle_cases`, then runs either LossMixingTrainer or Wave2WaveTrainer
    depending on `mix_loss`.

    :param fix_len: crop length in seconds handed to the dataset builder
    :param is_augment: whether augmented datasets are used (see handle_cases)
    :param milestones: LR-decay steps; None disables scheduling
    """
    # check args
    assert os.path.exists(meta_dir)

    # create model
    model = build_model(model_name).cuda()

    # multi-gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # create optimizers
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
    if milestones:
        # milestones may arrive as strings from a CLI; coerce to int
        milestones = [int(x) for x in list(milestones)]
        scheduler = MultiStepLR(optimizer, milestones, gamma=gamma)
    else:
        scheduler = None

    # handle cases — also returns the sample rate for the chosen case
    train_loader, valid_loader, sr = handle_cases(case_name, is_augment, meta_dir, batch_size, num_workers, fix_len)

    if mix_loss:
        trainer = LossMixingTrainer
    else:
        trainer = Wave2WaveTrainer

    # train
    trainer(
        model, optimizer, train_loader, valid_loader,
        max_step=max_step, valid_max_step=min(valid_max_step, len(valid_loader)), save_interval=save_interval,
        log_interval=log_interval,
        save_dir=save_dir, save_prefix=save_prefix, grad_clip=grad_clip, grad_norm=grad_norm,
        pretrained_path=pretrained_path, scheduler=scheduler, sr=sr
    ).run()
Esempio n. 7
0
def main(meta_dir: str,
         save_dir: str,
         save_prefix: str,
         pretrained_path: str = '',
         model_name: str = 'refine_unet_base',
         batch_size: int = 128,
         num_workers: int = 16,
         fix_len: float = 2.,
         lr: float = 5e-4,
         betas: Tuple[float, float] = (0.5, 0.9),
         weight_decay: float = 0.0,
         max_step: int = 200000,
         valid_max_step: int = 30,
         save_interval: int = 1000,
         log_interval: int = 50,
         grad_clip: float = 0.0,
         grad_norm: float = 30.0,
         is_augment: bool = True,
         milestones: Tuple[int, ...] = None,
         gamma: float = 0.1,
         is_dsd: bool = False,
         mix_loss: bool = False):
    """Train a wave-to-wave refinement model on DSD100 or VoiceBank.

    Chooses the dataset builder from `is_dsd`/`is_augment`, builds the model
    (DataParallel when several GPUs are visible), sets up Adam with an
    optional MultiStepLR schedule, and runs either LossMixingTrainer or
    Wave2WaveTrainer until `max_step`.

    :param fix_len: crop length in seconds (converted to samples via sr)
    :param is_dsd: True selects the DSD100 (44.1 kHz) case, else VoiceBank (22.05 kHz)
    :param milestones: LR-decay steps; None disables scheduling
    """
    # check args
    assert os.path.exists(meta_dir)

    # create model
    model = build_model(model_name).cuda()

    # multi-gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # create optimizers
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=betas,
                                 weight_decay=weight_decay)
    if milestones:
        # milestones may arrive as strings from a CLI; coerce to int
        milestones = [int(x) for x in list(milestones)]
        scheduler = MultiStepLR(optimizer, milestones, gamma=gamma)
    else:
        scheduler = None

    # adopt dsd100 case
    # BUGFIX: meta_cls was previously unbound when is_augment is False,
    # which raised NameError at the dataset_func(...) call below. Default it
    # to None; the non-augment builders receive meta_cls=None as before the
    # augment flag existed — confirm they tolerate it.
    meta_cls = None
    if is_dsd:
        sr = 44100
        if is_augment:
            dataset_func = get_datasets
            meta_cls = DSD100Meta
        else:
            dataset_func = dsd100.get_datasets
        is_audioset = False
    else:
        sr = 22050
        # load dataset
        if is_augment:
            dataset_func = get_datasets
            meta_cls = VoiceBankMeta
            is_audioset = True
        else:
            dataset_func = voice_bank.get_datasets
            is_audioset = False

    train_loader, valid_loader = dataset_func(meta_dir,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              meta_cls=meta_cls,
                                              is_audioset=is_audioset,
                                              fix_len=int(fix_len * sr),
                                              audio_mask=True)

    if mix_loss:
        trainer = LossMixingTrainer
    else:
        trainer = Wave2WaveTrainer

    # train
    trainer(model,
            optimizer,
            train_loader,
            valid_loader,
            max_step=max_step,
            valid_max_step=min(valid_max_step, len(valid_loader)),
            save_interval=save_interval,
            log_interval=log_interval,
            save_dir=save_dir,
            save_prefix=save_prefix,
            grad_clip=grad_clip,
            grad_norm=grad_norm,
            pretrained_path=pretrained_path,
            scheduler=scheduler,
            sr=sr).run()
Esempio n. 8
0
def main(vb_meta_dir: str,
         music_meta_dir: str,
         save_dir: str,
         save_prefix: str,
         pretrained_path: str = '',
         model_name: str = 'refine_unet_larger',
         batch_size: int = 128,
         num_workers: int = 16,
         fix_len: float = 2.,
         lr: float = 5e-4,
         betas: Tuple[float, float] = (0.5, 0.9),
         weight_decay: float = 0.0,
         max_step: int = 200000,
         valid_max_step: int = 50,
         save_interval: int = 1000,
         log_interval: int = 50,
         grad_clip: float = 0.0,
         grad_norm: float = 30.0,
         milestones: Tuple[int, ...] = None,
         gamma: float = 0.1,
         sample_rate: int = 44100,
         music_data_name: str = 'dsd100',
         mix_loss: bool = False):
    """Train a refinement model on VoiceBank concatenated with a music dataset.

    Either meta dir may be empty, in which case that dataset is dropped
    from the concatenation. The music source is MUSDB18 or DSD100
    depending on `music_data_name`.

    :param fix_len: crop length in seconds (converted to samples)
    :param milestones: LR-decay steps; None disables scheduling
    :raises NotImplementedError: for unknown `music_data_name`
    """
    # create model
    model = build_model(model_name).cuda()

    # multi-gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # create optimizers
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=betas,
                                 weight_decay=weight_decay)
    if milestones:
        # milestones may arrive as strings from a CLI; coerce to int
        milestones = [int(x) for x in list(milestones)]
        scheduler = MultiStepLR(optimizer, milestones, gamma=gamma)
    else:
        scheduler = None

    # choose music source dataset
    if music_data_name == 'musdb18':
        MusicMeta = MUSDB18Meta
    elif music_data_name == 'dsd100':
        MusicMeta = DSD100Meta
    else:
        raise NotImplementedError(
            f'{music_data_name} is not implemented ! choose one in [\'musdb18\', \'dsd100\']'
        )

    # make metas: drop entries whose meta dir is empty, keeping dirs and
    # meta classes aligned.
    # NOTE(review): zip(*[]) raises if BOTH dirs are falsy — presumably at
    # least one meta dir is always provided; confirm against callers.
    meta_dir_list = [vb_meta_dir, music_meta_dir]
    meta_cls_list = [VoiceBankMeta, MusicMeta]
    meta_dir_list, meta_cls_list = map(
        list,
        zip(*[(d, cls) for d, cls in zip(meta_dir_list, meta_cls_list) if d]))

    train_loader, valid_loader = get_concated_datasets(
        meta_dir_list,
        batch_size=batch_size,
        num_workers=num_workers,
        meta_cls_list=meta_cls_list,
        fix_len=int(fix_len * sample_rate),
        audio_mask=True)

    if mix_loss:
        trainer = LossMixingTrainer
    else:
        trainer = Wave2WaveTrainer

    # train
    trainer(model,
            optimizer,
            train_loader,
            valid_loader,
            max_step=max_step,
            valid_max_step=min(valid_max_step, len(valid_loader)),
            save_interval=save_interval,
            log_interval=log_interval,
            save_dir=save_dir,
            save_prefix=save_prefix,
            grad_clip=grad_clip,
            grad_norm=grad_norm,
            pretrained_path=pretrained_path,
            scheduler=scheduler,
            sr=sample_rate).run()
Esempio n. 9
0
def main(meta_dir: str,
         save_dir: str,
         save_prefix: str,
         pretrained_path: str = '',
         batch_size: int = 32,
         num_workers: int = 8,
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.5, 0.9),
         weight_decay: float = 0.0,
         pretrain_step: int = 200000,
         max_step: int = 1000000,
         save_interval: int = 10000,
         log_scala_interval: int = 20,
         log_heavy_interval: int = 1000,
         gamma: float = 0.5,
         seed: int = 1234):
    """Train a multi-band GAN vocoder in two phases.

    Phase 1 (steps [initial_step, pretrain_step)): pretrain the multi-band
    generator with multi-resolution STFT losses only.
    Phase 2 (steps up to max_step): adversarial training of generator and
    discriminator.

    Checkpoints (with optimizer/scheduler state) are written every
    `save_interval` steps; scalars are logged every `log_scala_interval`
    steps and audio/images every `log_heavy_interval` steps.

    :param pretrained_path: optional checkpoint to resume from
    :param pretrain_step: boundary between STFT-only and GAN phases
    :param seed: random seed forwarded to the dataset builder
    """
    #
    # prepare training
    #
    # create model
    mb_generator = build_model('generator_mb').cuda()
    discriminator = build_model('discriminator_base').cuda()

    # Multi-gpu is not required.

    # create optimizers
    mb_opt = torch.optim.Adam(mb_generator.parameters(),
                              lr=lr,
                              betas=betas,
                              weight_decay=weight_decay)
    dis_opt = torch.optim.Adam(discriminator.parameters(),
                               lr=lr,
                               betas=betas,
                               weight_decay=weight_decay)

    # make scheduler — discriminator milestones are offset because it only
    # starts training at pretrain_step
    mb_scheduler = MultiStepLR(mb_opt,
                               list(range(300000, 900000 + 1, 100000)),
                               gamma=gamma)
    dis_scheduler = MultiStepLR(dis_opt,
                                list(range(100000, 700000 + 1, 100000)),
                                gamma=gamma)

    # get datasets
    train_loader, valid_loader = get_datasets(meta_dir,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              crop_length=settings.SAMPLE_RATE,
                                              random_seed=seed)

    # repeat: turn the train loader into an endless iterator
    train_loader = repeat(train_loader)

    # build mel function
    mel_func, stft_funcs_for_loss = build_stft_functions()

    # build pqmf (analysis/synthesis filter bank for the sub-bands)
    pqmf_func = PQMF().cuda()

    # prepare logging
    writer, model_dir = prepare_logging(save_dir, save_prefix)

    # Training Saving Attributes
    best_loss = np.finfo(np.float32).max
    initial_step = 0

    # load model
    if pretrained_path:
        log(f'Pretrained path is given : {pretrained_path} . Loading...')
        chk = torch.load(pretrained_path)
        gen_chk, dis_chk = chk['generator'], chk['discriminator']
        gen_opt_chk, dis_opt_chk = chk['gen_opt'], chk['dis_opt']
        initial_step = int(chk['step'])
        loaded_loss = chk['loss']

        mb_generator.load_state_dict(gen_chk)
        discriminator.load_state_dict(dis_chk)
        mb_opt.load_state_dict(gen_opt_chk)
        dis_opt.load_state_dict(dis_opt_chk)
        if 'dis_scheduler' in chk:
            dis_scheduler_chk = chk['dis_scheduler']
            gen_scheduler_chk = chk['gen_scheduler']
            mb_scheduler.load_state_dict(gen_scheduler_chk)
            dis_scheduler.load_state_dict(dis_scheduler_chk)

        # NOTE(review): pokes private _step_count to realign the schedulers
        # with the resumed step — fragile across torch versions; confirm on
        # upgrade.
        mb_opt._step_count = initial_step
        mb_scheduler._step_count = initial_step
        dis_opt._step_count = initial_step - pretrain_step
        dis_scheduler._step_count = initial_step - pretrain_step

        mb_scheduler.step(initial_step)
        dis_scheduler.step(initial_step - pretrain_step)
        best_loss = loaded_loss

    #
    # Training !
    #
    # Pretraining generator (skipped entirely when resuming past pretrain_step)
    for step in range(initial_step, pretrain_step):
        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        # get multi-resolution stft loss   eq 9)
        loss, mb_loss, fb_loss = get_stft_loss(pred, wav, pred_subbands,
                                               target_subbands,
                                               stft_funcs_for_loss)

        # backward and update
        loss.backward()
        mb_opt.step()
        mb_scheduler.step()

        mb_opt.zero_grad()
        mb_generator.zero_grad()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/pretrain_loss',
                              loss.item(),
                              global_step=step)
            writer.add_scalar('train/mb_loss',
                              mb_loss.item(),
                              global_step=step)
            writer.add_scalar('train/fb_loss',
                              fb_loss.item(),
                              global_step=step)

            if step % log_heavy_interval == 0:
                writer.add_audio('train/pred_audio',
                                 pred_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)
                writer.add_audio('train/target_audio',
                                 target_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)

            # console
            msg = f'train: step: {step} / loss: {loss.item()} / mb_loss: {mb_loss.item()} / fb_loss: {fb_loss.item()}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_loss = 0.
            valid_mb_loss, valid_fb_loss = 0., 0.
            count = 0
            mb_generator.eval()

            for idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)

                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(
                        wav.unsqueeze(1))  # N, SUBBAND, T

                    # forward
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands,
                                                 target_subbands)

                    # pqmf synthesis
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)

                    # get stft loss
                    loss, mb_loss, fb_loss = get_stft_loss(
                        pred, wav, pred_subbands, target_subbands,
                        stft_funcs_for_loss)

                valid_loss += loss.item()
                valid_mb_loss += mb_loss.item()
                valid_fb_loss += fb_loss.item()
                count = idx

            valid_loss /= (count + 1)
            valid_mb_loss /= (count + 1)
            valid_fb_loss /= (count + 1)
            mb_generator.train()

            # log validation
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('valid/pretrain_loss',
                              valid_loss,
                              global_step=step)
            writer.add_scalar('valid/mb_loss', valid_mb_loss, global_step=step)
            writer.add_scalar('valid/fb_loss', valid_fb_loss, global_step=step)
            writer.add_audio('valid/pred_audio',
                             pred_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)
            writer.add_audio('valid/target_audio',
                             target_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)

            # console
            log(f'---- Valid loss: {valid_loss} / mb_loss: {valid_mb_loss} / fb_loss: {valid_fb_loss} ----'
                )

            #
            # save checkpoint
            #
            is_best = valid_loss < best_loss
            if is_best:
                best_loss = valid_loss
            save_checkpoint(mb_generator,
                            discriminator,
                            mb_opt,
                            dis_opt,
                            mb_scheduler,
                            dis_scheduler,
                            model_dir,
                            step,
                            valid_loss,
                            is_best=is_best)

    #
    # Train GAN
    #
    dis_block_layers = 6
    lambda_gen = 2.5
    best_loss = np.finfo(np.float32).max

    for step in range(max(pretrain_step, initial_step), max_step):

        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        #
        # Train Discriminator
        #

        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        with torch.no_grad():
            pred_mel = mel_func(pred.squeeze(1).detach())
            mel_err = F.l1_loss(mel, pred_mel).item()

        # if terminate_step > step:
        d_fake_det = discriminator(pred.detach())
        d_real = discriminator(wav.unsqueeze(1))

        # calculate discriminator losses  eq 1)
        # (every dis_block_layers-th entry is a block's final output)
        loss_D = 0

        for idx in range(dis_block_layers - 1, len(d_fake_det),
                         dis_block_layers):
            loss_D += torch.mean((d_fake_det[idx] - 1)**2)

        for idx in range(dis_block_layers - 1, len(d_real), dis_block_layers):
            loss_D += torch.mean(d_real[idx]**2)

        # train
        discriminator.zero_grad()
        loss_D.backward()
        dis_opt.step()
        dis_scheduler.step()

        #
        # Train Generator
        #
        d_fake = discriminator(pred)

        # calc generator loss   eq 8)
        loss_G = 0
        for idx in range(dis_block_layers - 1, len(d_fake), dis_block_layers):
            loss_G += ((d_fake[idx] - 1)**2).mean()

        loss_G *= lambda_gen

        # get multi-resolution stft loss
        loss_G += get_stft_loss(pred, wav, pred_subbands, target_subbands,
                                stft_funcs_for_loss)[0]
        # loss_G += get_spec_losses(pred, wav, stft_funcs_for_loss)[0]

        mb_generator.zero_grad()
        loss_G.backward()
        mb_opt.step()
        mb_scheduler.step()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/loss_G', loss_G.item(), global_step=step)
            writer.add_scalar('train/loss_D', loss_D.item(), global_step=step)
            writer.add_scalar('train/mel_err', mel_err, global_step=step)
            if step % log_heavy_interval == 0:
                target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
                pred_mel = imshow_to_buf(
                    mel_func(pred[:1, 0])[0].detach().cpu().numpy())

                writer.add_image('train/target_mel',
                                 target_mel,
                                 global_step=step)
                writer.add_image('train/pred_mel', pred_mel, global_step=step)
                writer.add_audio('train/pred_audio',
                                 pred_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)
                writer.add_audio('train/target_audio',
                                 target_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)

            # console
            msg = f'train: step: {step} / loss_G: {loss_G.item()} / loss_D: {loss_D.item()} / ' \
                f' mel_err: {mel_err}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_g_loss, valid_d_loss, valid_mel_loss = 0., 0., 0.
            count = 0
            mb_generator.eval()
            discriminator.eval()

            for idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)

                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(
                        wav.unsqueeze(1))  # N, SUBBAND, T

                    # Discriminator
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands,
                                                 target_subbands)

                    # pqmf synthesis
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)

                    # Mel Error
                    pred_mel = mel_func(pred.squeeze(1).detach())
                    mel_err = F.l1_loss(mel, pred_mel).item()

                    #
                    # discriminator part
                    #
                    d_fake_det = discriminator(pred.detach())
                    d_real = discriminator(wav.unsqueeze(1))

                    loss_D = 0

                    # BUGFIX: the inner layer loops previously reused `idx`,
                    # clobbering the outer batch index so `count = idx` below
                    # picked up a layer index and the averages used the wrong
                    # denominator. Renamed to `layer_idx`.
                    for layer_idx in range(dis_block_layers - 1,
                                           len(d_fake_det),
                                           dis_block_layers):
                        loss_D += torch.mean((d_fake_det[layer_idx] - 1)**2)

                    for layer_idx in range(dis_block_layers - 1, len(d_real),
                                           dis_block_layers):
                        loss_D += torch.mean(d_real[layer_idx]**2)

                    #
                    # generator part
                    #
                    d_fake = discriminator(pred)

                    # calc generator loss
                    loss_G = 0
                    for layer_idx in range(dis_block_layers - 1, len(d_fake),
                                           dis_block_layers):
                        loss_G += ((d_fake[layer_idx] - 1)**2).mean()

                    loss_G *= lambda_gen

                    # get stft loss
                    stft_loss = get_stft_loss(pred, wav, pred_subbands,
                                              target_subbands,
                                              stft_funcs_for_loss)[0]
                    loss_G += stft_loss

                valid_d_loss += loss_D.item()
                valid_g_loss += loss_G.item()
                valid_mel_loss += mel_err
                count = idx

            valid_d_loss /= (count + 1)
            valid_g_loss /= (count + 1)
            valid_mel_loss /= (count + 1)

            mb_generator.train()
            discriminator.train()

            # log validation
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
            pred_mel = imshow_to_buf(
                mel_func(pred[:1, 0])[0].detach().cpu().numpy())

            writer.add_image('valid/target_mel', target_mel, global_step=step)
            writer.add_image('valid/pred_mel', pred_mel, global_step=step)
            writer.add_scalar('valid/loss_G', valid_g_loss, global_step=step)
            writer.add_scalar('valid/loss_D', valid_d_loss, global_step=step)
            writer.add_scalar('valid/mel_err',
                              valid_mel_loss,
                              global_step=step)
            writer.add_audio('valid/pred_audio',
                             pred_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)
            writer.add_audio('valid/target_audio',
                             target_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)

            # console
            log(f'---- loss_G: {valid_g_loss} / loss_D: {valid_d_loss} / mel loss : {valid_mel_loss} ----'
                )

            #
            # save checkpoint
            #
            is_best = valid_g_loss < best_loss
            if is_best:
                best_loss = valid_g_loss
            save_checkpoint(mb_generator,
                            discriminator,
                            mb_opt,
                            dis_opt,
                            mb_scheduler,
                            dis_scheduler,
                            model_dir,
                            step,
                            valid_g_loss,
                            is_best=is_best)

    log('----- Finish ! -----')
Esempio n. 10
0
def _main(
    meta_dir: str,
    save_prefix: str = '',
    model_name: str = 'refine_unet_base',  # or refine_spectrogram_unet
    save_dir: str = 'savedir',
    batch_size: int = 128,
    num_workers: int = 16,
    fix_len: float = 2.,
    lr: float = 5e-4,
    beta1: float = 0.5,
    beta2: float = 0.9,
    weight_decay: float = 0.0,
    max_step: int = 100000,
    valid_max_step: int = 30,
    save_interval: int = 1000,
    log_interval: int = 100,
    grad_clip: float = 0.0,
    grad_norm: float = 30.0,
    milestones: Tuple[int, ...] = None,
    gamma: float = 0.2,
    is_augment: bool = True,
    is_dsd: bool = False,
    # model args
    hidden_dim: int = 768,
    filter_len: int = 512,
    hop_len: int = 64,
    block_layers: int = 4,
    layers: int = 4,
    kernel_size: int = 3,
    norm: str = 'ins',
    act: str = 'comp',
    refine_layers: int = 1,
):
    """Train a refine-UNet variant with explicit model hyperparameters.

    Builds the model from the `model args` group, trains it with
    Wave2WaveTrainer on DSD100 or VoiceBank (chosen by `is_dsd`), and
    returns a dict with the final loss and a status flag (suitable for a
    hyperparameter-search driver).

    :returns: {'loss': final trainer loss, 'status': 'ok'}
    """
    betas = beta1, beta2

    # setup model args
    model_args = {
        'hidden_dim': hidden_dim,
        'filter_len': filter_len,
        'hop_len': hop_len,
        'spec_dim': filter_len // 2 + 1,  # one-sided spectrum bins
        'block_layers': block_layers,
        'layers': layers,
        'kernel_size': kernel_size,
        'norm': norm,
        'refine_layers': refine_layers,
        'act': act
    }

    # create model
    model = build_model(model_name, extra_kwargs=model_args).cuda()

    # multi-gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # create optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=betas,
                                 weight_decay=weight_decay)

    if milestones:
        # milestones may arrive as strings from a CLI; coerce to int
        milestones = [int(x) for x in list(milestones)]
        scheduler = MultiStepLR(optimizer, milestones, gamma=gamma)
    else:
        scheduler = None

    # adopt dsd100 case
    # BUGFIX: meta_cls was previously unbound when is_augment is False,
    # which raised NameError at the dataset_func(...) call below. Default it
    # to None — confirm the non-augment builders tolerate meta_cls=None.
    meta_cls = None
    if is_dsd:
        sr = 44100
        if is_augment:
            dataset_func = get_datasets
            meta_cls = DSD100Meta
        else:
            dataset_func = dsd100.get_datasets
    else:
        sr = 22050
        # load dataset
        if is_augment:
            dataset_func = get_datasets
            meta_cls = VoiceBankMeta
        else:
            dataset_func = voice_bank.get_datasets

    train_loader, valid_loader = dataset_func(meta_dir,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              meta_cls=meta_cls,
                                              fix_len=int(fix_len * sr),
                                              audio_mask=True)

    # train (run() blocks until max_step and returns the final loss)
    loss = Wave2WaveTrainer(model,
                            optimizer,
                            train_loader,
                            valid_loader,
                            max_step=max_step,
                            valid_max_step=min(valid_max_step,
                                               len(valid_loader)),
                            save_interval=save_interval,
                            log_interval=log_interval,
                            save_dir=save_dir,
                            save_prefix=save_prefix,
                            grad_clip=grad_clip,
                            grad_norm=grad_norm,
                            pretrained_path='',
                            scheduler=scheduler,
                            sr=sr).run()

    return {
        'loss': loss,
        'status': 'ok',
    }
Esempio n. 11
0
def main(train_path: str,
         preprocessed_path: str,
         save_dir: str,
         save_prefix: str,
         model_name: str,
         pretrained_path: str = '',
         num_workers: int = 16,
         batch_size: int = 16,
         pitch_feature: str = 'phoneme',
         energy_feature: str = 'phoneme',
         pitch_min: float = 0.,
         energy_min: float = 0.,
         lr: float = 2e-4,
         weight_decay: float = 0.0001,
         betas=(0.9, 0.98),
         max_step: int = 400000,
         group_size: int = 4,
         save_interval: int = 10000,
         log_interval: int = 50,
         grad_clip: float = 0.0,
         grad_norm: float = 5.0,
         milestones: Tuple[int] = None,
         gamma: float = 0.2,
         sr: int = 22050,
         seed: int = 2021,
         is_reference: bool = False):
    """Build a model and run text-to-speech training via ``BaseTrainer``.

    Args:
        train_path: path of the training metadata file.
        preprocessed_path: directory holding preprocessed features.
        save_dir: directory where checkpoints are written.
        save_prefix: filename prefix for saved checkpoints.
        model_name: key passed to ``build_model`` to select the architecture.
        pretrained_path: optional checkpoint to warm-start from ('' = none).
        num_workers: DataLoader worker process count.
        batch_size: per-bucket batch size; the loader draws
            ``batch_size * group_size`` items and the dataset re-buckets them.
        pitch_feature / energy_feature: variance-feature granularity
            ('phoneme' or frame level), forwarded to the trainer.
        pitch_min / energy_min: lower clamps for the variance features.
        lr, weight_decay, betas: AdamW hyper-parameters.
        max_step: total number of optimizer steps to run.
        group_size: bucketing factor for length-sorted batching.
        save_interval / log_interval: checkpoint / logging periods in steps.
        grad_clip / grad_norm: gradient value-clip and norm-clip thresholds
            (0.0 disables the corresponding clipping in the trainer).
        milestones: optional step milestones for MultiStepLR decay.
        gamma: LR decay factor applied at each milestone.
        sr: audio sample rate, forwarded to the trainer.
        seed: RNG seed forwarded to the trainer.
        is_reference: whether to train with reference-based conditioning.
    """
    # create model
    model = build_model(model_name).cuda()

    # multi-gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # create optimizers
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=lr,
                                  betas=betas,
                                  weight_decay=weight_decay)
    if milestones:
        # Coerce to ints: CLI front-ends often deliver milestones as strings,
        # and MultiStepLR compares them against the integer step counter.
        # (Matches the normalization done by the other training entry points.)
        milestones = [int(x) for x in list(milestones)]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones,
                                                         gamma=gamma)
    else:
        scheduler = None

    dataset = Dataset(train_path,
                      preprocessed_path,
                      pitch_min=pitch_min,
                      energy_min=energy_min,
                      text_cleaners=['english_cleaners'],
                      batch_size=batch_size,
                      sort=True,
                      drop_last=True,
                      is_reference=is_reference)
    # group_size oversized batches let the dataset sort by length and
    # split into `group_size` similarly-sized buckets inside collate_fn.
    train_loader = DataLoader(dataset,
                              batch_size=batch_size * group_size,
                              shuffle=True,
                              collate_fn=dataset.collate_fn,
                              num_workers=num_workers)

    # train (no validation loader; valid_max_step=1 keeps the trainer happy)
    BaseTrainer(model,
                optimizer,
                train_loader,
                None,
                max_step=max_step,
                valid_max_step=1,
                save_interval=save_interval,
                log_interval=log_interval,
                pitch_feature=pitch_feature,
                energy_feature=energy_feature,
                save_dir=save_dir,
                save_prefix=save_prefix,
                grad_clip=grad_clip,
                grad_norm=grad_norm,
                pretrained_path=pretrained_path,
                sr=sr,
                scheduler=scheduler,
                seed=seed,
                is_reference=is_reference).run()
Esempio n. 12
0
def main(meta_dir: str,
         save_dir: str,
         save_prefix: str = '',
         pretrained_path: str = '',
         model_name: str = 'refine_unet_base',
         batch_size: int = 128,
         num_workers: int = 16,
         fix_len: float = 2.,
         lr: float = 3e-4,
         betas: Tuple[float] = (0.5, 0.9),
         weight_decay: float = 0.0,
         max_step: int = 100000,
         valid_max_step: int = 30,
         save_interval: int = 1000,
         log_interval: int = 50,
         grad_clip: float = 0.0,
         grad_norm: float = 30.0,
         is_audioset: bool = True,
         milestones: Tuple[int] = None,
         gamma: float = 0.1):
    """Build a wave-to-wave model and run training via ``Wave2WaveTrainer``.

    Args:
        meta_dir: directory containing dataset metadata; must exist.
        save_dir: directory where checkpoints are written.
        save_prefix: filename prefix for saved checkpoints.
        pretrained_path: optional checkpoint to warm-start from ('' = none).
        model_name: key passed to ``build_model`` to select the architecture.
        batch_size: training batch size.
        num_workers: DataLoader worker process count.
        fix_len: fixed clip length in seconds; converted to samples with
            ``settings.SAMPLE_RATE`` before reaching the dataset.
        lr, betas, weight_decay: Adam hyper-parameters.
        max_step: total number of optimizer steps to run.
        valid_max_step: cap on validation steps (further bounded by the
            validation loader's length).
        save_interval / log_interval: checkpoint / logging periods in steps.
        grad_clip / grad_norm: gradient value-clip and norm-clip thresholds
            (0.0 disables the corresponding clipping in the trainer).
        is_audioset: choose the audioset pipeline; otherwise voice_bank's.
        milestones: optional step milestones for MultiStepLR decay.
        gamma: LR decay factor applied at each milestone.

    Raises:
        FileNotFoundError: if ``meta_dir`` does not exist.
    """
    # Validate input explicitly: `assert` would be stripped under `python -O`.
    if not os.path.exists(meta_dir):
        raise FileNotFoundError(f'meta_dir does not exist: {meta_dir}')

    # create model
    model = build_model(model_name).cuda()

    # multi-gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # create optimizers
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=betas,
                                 weight_decay=weight_decay)
    if milestones:
        # CLI front-ends may deliver milestones as strings; MultiStepLR
        # compares them against the integer step counter.
        milestones = [int(x) for x in list(milestones)]
        scheduler = MultiStepLR(optimizer, milestones, gamma=gamma)
    else:
        scheduler = None

    # load dataset
    if is_audioset:
        dataset_func = get_datasets
    else:
        dataset_func = voice_bank.get_datasets

    train_loader, valid_loader = dataset_func(
        meta_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        fix_len=int(fix_len * settings.SAMPLE_RATE),
        audio_mask=True)

    # train
    Wave2WaveTrainer(model,
                     optimizer,
                     train_loader,
                     valid_loader,
                     max_step=max_step,
                     valid_max_step=min(valid_max_step, len(valid_loader)),
                     save_interval=save_interval,
                     log_interval=log_interval,
                     save_dir=save_dir,
                     save_prefix=save_prefix,
                     grad_clip=grad_clip,
                     grad_norm=grad_norm,
                     pretrained_path=pretrained_path,
                     scheduler=scheduler).run()