def main(meta_dir: str, save_dir: str, save_prefix: str, pretrained_path: str = '',
         model_name: str = 'tacotron2_base', batch_size: int = 32, num_workers: int = 16,
         lr: float = 1e-4, betas: Tuple[float] = (0.9, 0.99), weight_decay: float = 0.0,
         max_step: int = 100000, valid_max_step: int = 50, save_interval: int = 1000,
         log_interval: int = 50, grad_clip: float = 0.0, grad_norm: float = 30.0,
         gamma: float = 0.1, milestones: Tuple[int] = None, sr: int = 22050):
    """Entry point that wires model, optimizer, data and trainer together for Tacotron2 training."""
    # Build the network on GPU; fan out over all visible devices when more than one.
    net = build_model(model_name).cuda()
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    # Adam optimizer, with an optional multi-step LR decay schedule.
    opt = torch.optim.Adam(net.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
    if milestones:
        lr_sched = MultiStepLR(opt, [int(m) for m in list(milestones)], gamma=gamma)
    else:
        lr_sched = None

    train_loader, valid_loader = get_datasets(
        meta_dir, batch_size=batch_size, num_workers=num_workers
    )

    # Hand everything over to the trainer loop; validation is capped by the loader length.
    Tacotron2Trainer(
        net, opt, train_loader, valid_loader,
        max_step=max_step,
        valid_max_step=min(valid_max_step, len(valid_loader)),
        save_interval=save_interval,
        log_interval=log_interval,
        save_dir=save_dir,
        save_prefix=save_prefix,
        grad_clip=grad_clip,
        grad_norm=grad_norm,
        pretrained_path=pretrained_path,
        scheduler=lr_sched,
        sr=sr
    ).run()
def __load_model(model_name: str, pretrained_path: str) -> torch.nn.Module:
    """Build ``model_name``, restore weights from ``pretrained_path`` and return it in eval mode on GPU.

    The checkpoint is deserialized onto CPU first (``map_location='cpu'``) so
    loading does not depend on the device the checkpoint was saved from and
    avoids a transient duplicate copy on the GPU; ``load_state_dict`` then
    copies the weights into the CUDA model.
    """
    print('Load model ...')
    model = build_model(model_name).cuda()
    chk = torch.load(pretrained_path, map_location='cpu')['model']
    model.load_state_dict(get_loadable_checkpoint(chk))
    model.eval()
    return model
def __load_model(model_name: str, pretrained_path: str) -> torch.nn.Module:
    """Construct the model, restore CPU-mapped weights from the checkpoint, and return it in eval mode."""
    print('Load model ...')
    net = build_model(model_name)
    # Map tensors onto CPU regardless of the device the checkpoint was saved from.
    state = torch.load(pretrained_path, map_location=torch.device('cpu'))['model']
    net.load_state_dict(get_loadable_checkpoint(state))
    net.eval()
    return net
def __init__(self, model_name: str = 'hifi_gan_v1'):
    """Set up a mel-spectrogram encoder and a pretrained HiFi-GAN decoder.

    Args:
        model_name: which generator to load; one of ``hifi_gan_v1`` / ``hifi_gan_v2``.
    """
    assert model_name in ['hifi_gan_v1', 'hifi_gan_v2'], \
        'Model name {} is not valid! choose in {}'.format(model_name,
                                                          str(['hifi_gan_v1', 'hifi_gan_v2']))
    # encoder: waveform -> mel spectrogram
    self.encoder = MelSpectrogram(**vars(AudioParameters()))
    # decoder: mel -> waveform. The decoder is never moved to GPU here, so the
    # checkpoint must be mapped to CPU explicitly — otherwise a GPU-saved
    # checkpoint fails to load on a CUDA-less machine.
    self.decoder = build_model(model_name)
    chkpt = torch.load(CHECKPOINTS[model_name], map_location='cpu')
    self.decoder.load_state_dict(chkpt['generator'])
def __init__(self, model_name: str = 'hifi_gan_v1', chk_path: str = '', device='cpu'):
    """Set up a mel-spectrogram encoder and a pretrained HiFi-GAN decoder on ``device``.

    Args:
        model_name: which generator to load; one of ``hifi_gan_v1/v2/v3``.
        chk_path: optional explicit checkpoint path; falls back to the
            registered ``CHECKPOINTS`` entry when empty.
        device: torch device for both encoder and decoder.
    """
    assert model_name in ['hifi_gan_v1', 'hifi_gan_v2', 'hifi_gan_v3'], \
        'Model name {} is not valid! choose in {}'.format(
            model_name, str(['hifi_gan_v1', 'hifi_gan_v2', 'hifi_gan_v3']))
    # encoder: waveform -> mel spectrogram
    self.encoder = MelSpectrogram(**vars(AudioParameters())).to(device)
    # decoder: mel -> waveform
    self.decoder = build_model(model_name).to(device)
    # Map the checkpoint onto the target device: without map_location a
    # GPU-saved checkpoint cannot be loaded when device='cpu'.
    if chk_path:
        chkpt = torch.load(chk_path, map_location=device)
    else:
        chkpt = torch.load(CHECKPOINTS[model_name], map_location=device)
    self.decoder.load_state_dict(chkpt['generator'])
    # Fold weight-norm parametrization into the weights for inference.
    self.decoder.remove_weight_norm()
def main(meta_dir: str, save_dir: str, save_prefix: str, pretrained_path: str = '',
         model_name: str = 'refine_unet_base', batch_size: int = 128, num_workers: int = 16,
         fix_len: float = 2., lr: float = 5e-4, betas: Tuple[float] = (0.5, 0.9),
         weight_decay: float = 0.0, max_step: int = 200000, valid_max_step: int = 30,
         save_interval: int = 1000, log_interval: int = 50, grad_clip: float = 0.0,
         grad_norm: float = 30.0, is_augment: bool = True, milestones: Tuple[int] = None,
         gamma: float = 0.1, case_name: str = 'voice_bank', mix_loss: bool = False):
    """Train a wave-to-wave model on the dataset selected by ``case_name``."""
    # The metadata directory must already exist.
    assert os.path.exists(meta_dir)

    # Network on GPU, data-parallel when several devices are visible.
    net = build_model(model_name).cuda()
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    # Adam optimizer with an optional multi-step LR schedule.
    opt = torch.optim.Adam(net.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
    if milestones:
        lr_sched = MultiStepLR(opt, [int(m) for m in list(milestones)], gamma=gamma)
    else:
        lr_sched = None

    # Dataset selection (and its sample rate) is delegated to the case handler.
    train_loader, valid_loader, sr = handle_cases(case_name, is_augment, meta_dir,
                                                 batch_size, num_workers, fix_len)

    # Pick the trainer class, then run.
    trainer_cls = LossMixingTrainer if mix_loss else Wave2WaveTrainer
    trainer_cls(
        net, opt, train_loader, valid_loader,
        max_step=max_step,
        valid_max_step=min(valid_max_step, len(valid_loader)),
        save_interval=save_interval,
        log_interval=log_interval,
        save_dir=save_dir,
        save_prefix=save_prefix,
        grad_clip=grad_clip,
        grad_norm=grad_norm,
        pretrained_path=pretrained_path,
        scheduler=lr_sched,
        sr=sr
    ).run()
def main(meta_dir: str, save_dir: str, save_prefix: str, pretrained_path: str = '',
         model_name: str = 'refine_unet_base', batch_size: int = 128, num_workers: int = 16,
         fix_len: float = 2., lr: float = 5e-4, betas: Tuple[float] = (0.5, 0.9),
         weight_decay: float = 0.0, max_step: int = 200000, valid_max_step: int = 30,
         save_interval: int = 1000, log_interval: int = 50, grad_clip: float = 0.0,
         grad_norm: float = 30.0, is_augment: bool = True, milestones: Tuple[int] = None,
         gamma: float = 0.1, is_dsd: bool = False, mix_loss: bool = False):
    """Train a wave-to-wave model on either the DSD100 or the VoiceBank dataset.

    ``is_dsd`` selects the dataset family (and its sample rate); ``is_augment``
    selects the augmented dataset factory; ``mix_loss`` picks the trainer class.
    """
    # check args
    assert os.path.exists(meta_dir)
    # create model
    model = build_model(model_name).cuda()
    # multi-gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    # create optimizers
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas,
                                 weight_decay=weight_decay)
    if milestones:
        milestones = [int(x) for x in list(milestones)]
        scheduler = MultiStepLR(optimizer, milestones, gamma=gamma)
    else:
        scheduler = None
    # adopt dsd100 case
    if is_dsd:
        sr = 44100
        if is_augment:
            # NOTE(review): this branch never assigns `is_audioset`, yet it is
            # passed to `dataset_func` below -> NameError when
            # is_dsd=True, is_augment=True. Confirm intended default.
            dataset_func = get_datasets
            meta_cls = DSD100Meta
        else:
            # NOTE(review): `meta_cls` is not assigned in this branch but is
            # passed to `dataset_func` below -> NameError when
            # is_dsd=True, is_augment=False. Also unclear whether
            # dsd100.get_datasets accepts these keyword arguments — verify.
            dataset_func = dsd100.get_datasets
            is_audioset = False
    else:
        sr = 22050
        # load dataset
        if is_augment:
            dataset_func = get_datasets
            meta_cls = VoiceBankMeta
            is_audioset = True
        else:
            # NOTE(review): `meta_cls` is not assigned here either -> NameError
            # when is_dsd=False, is_augment=False. Only the default
            # (is_augment=True) path is safe as written.
            dataset_func = voice_bank.get_datasets
            is_audioset = False
    # fix_len is given in seconds; convert to samples at the chosen rate.
    train_loader, valid_loader = dataset_func(meta_dir, batch_size=batch_size,
                                              num_workers=num_workers, meta_cls=meta_cls,
                                              is_audioset=is_audioset,
                                              fix_len=int(fix_len * sr), audio_mask=True)
    if mix_loss:
        trainer = LossMixingTrainer
    else:
        trainer = Wave2WaveTrainer
    # train
    trainer(model, optimizer, train_loader, valid_loader,
            max_step=max_step,
            valid_max_step=min(valid_max_step, len(valid_loader)),
            save_interval=save_interval,
            log_interval=log_interval,
            save_dir=save_dir,
            save_prefix=save_prefix,
            grad_clip=grad_clip,
            grad_norm=grad_norm,
            pretrained_path=pretrained_path,
            scheduler=scheduler,
            sr=sr).run()
def main(vb_meta_dir: str, music_meta_dir: str, save_dir: str, save_prefix: str,
         pretrained_path: str = '', model_name: str = 'refine_unet_larger',
         batch_size: int = 128, num_workers: int = 16, fix_len: float = 2.,
         lr: float = 5e-4, betas: Tuple[float] = (0.5, 0.9), weight_decay: float = 0.0,
         max_step: int = 200000, valid_max_step: int = 50, save_interval: int = 1000,
         log_interval: int = 50, grad_clip: float = 0.0, grad_norm: float = 30.0,
         milestones: Tuple[int] = None, gamma: float = 0.1, sample_rate: int = 44100,
         music_data_name: str = 'dsd100', mix_loss: bool = False):
    """Train a wave-to-wave model on the concatenation of VoiceBank and a music dataset."""
    # Network on GPU; data-parallel across all visible devices.
    net = build_model(model_name).cuda()
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    # Adam optimizer with an optional multi-step LR schedule.
    opt = torch.optim.Adam(net.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
    if milestones:
        lr_sched = MultiStepLR(opt, [int(m) for m in list(milestones)], gamma=gamma)
    else:
        lr_sched = None

    # Resolve the music-source metadata class from its name.
    if music_data_name == 'musdb18':
        MusicMeta = MUSDB18Meta
    elif music_data_name == 'dsd100':
        MusicMeta = DSD100Meta
    else:
        raise NotImplementedError(
            f'{music_data_name} is not implemented ! \nchoose one in [\'musdb18\', \'dsd100\']'
        )

    # Keep only the (dir, meta-class) pairs whose directory was actually given.
    pairs = [(d, cls)
             for d, cls in zip([vb_meta_dir, music_meta_dir], [VoiceBankMeta, MusicMeta])
             if d]
    meta_dir_list, meta_cls_list = map(list, zip(*pairs))

    train_loader, valid_loader = get_concated_datasets(
        meta_dir_list, batch_size=batch_size, num_workers=num_workers,
        meta_cls_list=meta_cls_list, fix_len=int(fix_len * sample_rate),
        audio_mask=True)

    # Pick the trainer class, then run.
    trainer_cls = LossMixingTrainer if mix_loss else Wave2WaveTrainer
    trainer_cls(net, opt, train_loader, valid_loader,
                max_step=max_step,
                valid_max_step=min(valid_max_step, len(valid_loader)),
                save_interval=save_interval,
                log_interval=log_interval,
                save_dir=save_dir,
                save_prefix=save_prefix,
                grad_clip=grad_clip,
                grad_norm=grad_norm,
                pretrained_path=pretrained_path,
                scheduler=lr_sched,
                sr=sample_rate).run()
def main(meta_dir: str, save_dir: str, save_prefix: str, pretrained_path: str = '',
         batch_size: int = 32, num_workers: int = 8, lr: float = 1e-4,
         betas: Tuple[float, float] = (0.5, 0.9), weight_decay: float = 0.0,
         pretrain_step: int = 200000, max_step: int = 1000000,
         save_interval: int = 10000, log_scala_interval: int = 20,
         log_heavy_interval: int = 1000, gamma: float = 0.5, seed: int = 1234):
    """Train a multi-band MelGAN vocoder: STFT-loss pretraining, then adversarial training.

    Fix over the previous revision: the GAN validation loop reused ``idx`` both
    as the ``enumerate`` batch index and as the discriminator-layer index, so
    ``count`` ended up holding a layer index and the validation losses were
    averaged over the wrong denominator. The inner loops now use ``layer_idx``.
    """
    #
    # prepare training
    #
    # create models (multi-GPU is not required here)
    mb_generator = build_model('generator_mb').cuda()
    discriminator = build_model('discriminator_base').cuda()

    # create optimizers
    mb_opt = torch.optim.Adam(mb_generator.parameters(), lr=lr, betas=betas,
                              weight_decay=weight_decay)
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=betas,
                               weight_decay=weight_decay)

    # LR decay schedules (generator decays later than the discriminator)
    mb_scheduler = MultiStepLR(mb_opt, list(range(300000, 900000 + 1, 100000)), gamma=gamma)
    dis_scheduler = MultiStepLR(dis_opt, list(range(100000, 700000 + 1, 100000)), gamma=gamma)

    # get datasets; the train loader is wrapped into an infinite iterator
    train_loader, valid_loader = get_datasets(meta_dir, batch_size=batch_size,
                                              num_workers=num_workers,
                                              crop_length=settings.SAMPLE_RATE,
                                              random_seed=seed)
    train_loader = repeat(train_loader)

    # mel function for the generator input, STFT functions for the loss
    mel_func, stft_funcs_for_loss = build_stft_functions()

    # PQMF analysis/synthesis filter bank for sub-band processing
    pqmf_func = PQMF().cuda()

    # prepare logging
    writer, model_dir = prepare_logging(save_dir, save_prefix)

    # training/saving state
    best_loss = np.finfo(np.float32).max
    initial_step = 0

    # optionally resume from a checkpoint
    if pretrained_path:
        log(f'Pretrained path is given : {pretrained_path} . Loading...')
        chk = torch.load(pretrained_path)
        gen_chk, dis_chk = chk['generator'], chk['discriminator']
        gen_opt_chk, dis_opt_chk = chk['gen_opt'], chk['dis_opt']
        initial_step = int(chk['step'])
        loaded_loss = chk['loss']
        mb_generator.load_state_dict(gen_chk)
        discriminator.load_state_dict(dis_chk)
        mb_opt.load_state_dict(gen_opt_chk)
        dis_opt.load_state_dict(dis_opt_chk)
        if 'dis_scheduler' in chk:
            dis_scheduler_chk = chk['dis_scheduler']
            gen_scheduler_chk = chk['gen_scheduler']
            mb_scheduler.load_state_dict(gen_scheduler_chk)
            dis_scheduler.load_state_dict(dis_scheduler_chk)
        # realign optimizer/scheduler step counters with the resumed step
        # (the discriminator only starts after the pretraining phase)
        mb_opt._step_count = initial_step
        mb_scheduler._step_count = initial_step
        dis_opt._step_count = initial_step - pretrain_step
        dis_scheduler._step_count = initial_step - pretrain_step
        mb_scheduler.step(initial_step)
        dis_scheduler.step(initial_step - pretrain_step)
        best_loss = loaded_loss

    #
    # Training !
    #
    # Phase 1: pretrain the generator with multi-resolution STFT losses only
    for step in range(initial_step, pretrain_step):
        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf analysis: waveform -> sub-bands
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis: sub-bands -> waveform
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        # multi-resolution stft loss (paper eq. 9)
        loss, mb_loss, fb_loss = get_stft_loss(pred, wav, pred_subbands,
                                               target_subbands, stft_funcs_for_loss)

        # backward and update
        loss.backward()
        mb_opt.step()
        mb_scheduler.step()
        mb_opt.zero_grad()
        mb_generator.zero_grad()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/pretrain_loss', loss.item(), global_step=step)
            writer.add_scalar('train/mb_loss', mb_loss.item(), global_step=step)
            writer.add_scalar('train/fb_loss', fb_loss.item(), global_step=step)
            if step % log_heavy_interval == 0:
                writer.add_audio('train/pred_audio', pred_audio,
                                 sample_rate=settings.SAMPLE_RATE, global_step=step)
                writer.add_audio('train/target_audio', target_audio,
                                 sample_rate=settings.SAMPLE_RATE, global_step=step)
            # console
            msg = f'train: step: {step} / loss: {loss.item()} / ' \
                  f'mb_loss: {mb_loss.item()} / fb_loss: {fb_loss.item()}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_loss = 0.
            valid_mb_loss, valid_fb_loss = 0., 0.
            count = 0
            mb_generator.eval()
            for batch_idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)
                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T
                    # forward
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands, target_subbands)
                    # pqmf synthesis
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)
                    # get stft loss
                    loss, mb_loss, fb_loss = get_stft_loss(pred, wav, pred_subbands,
                                                           target_subbands, stft_funcs_for_loss)
                    valid_loss += loss.item()
                    valid_mb_loss += mb_loss.item()
                    valid_fb_loss += fb_loss.item()
                count = batch_idx
            valid_loss /= (count + 1)
            valid_mb_loss /= (count + 1)
            valid_fb_loss /= (count + 1)
            mb_generator.train()

            # log validation
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('valid/pretrain_loss', valid_loss, global_step=step)
            writer.add_scalar('valid/mb_loss', valid_mb_loss, global_step=step)
            writer.add_scalar('valid/fb_loss', valid_fb_loss, global_step=step)
            writer.add_audio('valid/pred_audio', pred_audio,
                             sample_rate=settings.SAMPLE_RATE, global_step=step)
            writer.add_audio('valid/target_audio', target_audio,
                             sample_rate=settings.SAMPLE_RATE, global_step=step)
            # console
            log(f'---- Valid loss: {valid_loss} / mb_loss: {valid_mb_loss} '
                f'/ fb_loss: {valid_fb_loss} ----')

            #
            # save checkpoint
            #
            is_best = valid_loss < best_loss
            if is_best:
                best_loss = valid_loss
            save_checkpoint(mb_generator, discriminator, mb_opt, dis_opt,
                            mb_scheduler, dis_scheduler, model_dir, step,
                            valid_loss, is_best=is_best)

    #
    # Phase 2: adversarial training (GAN)
    #
    dis_block_layers = 6  # stride between the final output of each discriminator block
    lambda_gen = 2.5      # weight on the adversarial generator loss
    best_loss = np.finfo(np.float32).max

    for step in range(max(pretrain_step, initial_step), max_step):
        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        #
        # Train Discriminator
        #
        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        with torch.no_grad():
            pred_mel = mel_func(pred.squeeze(1).detach())
            mel_err = F.l1_loss(mel, pred_mel).item()

        d_fake_det = discriminator(pred.detach())
        d_real = discriminator(wav.unsqueeze(1))

        # discriminator losses (paper eq. 1): only each block's final output
        loss_D = 0
        for layer_idx in range(dis_block_layers - 1, len(d_fake_det), dis_block_layers):
            loss_D += torch.mean((d_fake_det[layer_idx] - 1) ** 2)
        for layer_idx in range(dis_block_layers - 1, len(d_real), dis_block_layers):
            loss_D += torch.mean(d_real[layer_idx] ** 2)

        # train
        discriminator.zero_grad()
        loss_D.backward()
        dis_opt.step()
        dis_scheduler.step()

        #
        # Train Generator
        #
        d_fake = discriminator(pred)

        # generator adversarial loss (paper eq. 8)
        loss_G = 0
        for layer_idx in range(dis_block_layers - 1, len(d_fake), dis_block_layers):
            loss_G += ((d_fake[layer_idx] - 1) ** 2).mean()
        loss_G *= lambda_gen

        # add the multi-resolution stft loss
        loss_G += get_stft_loss(pred, wav, pred_subbands, target_subbands,
                                stft_funcs_for_loss)[0]

        mb_generator.zero_grad()
        loss_G.backward()
        mb_opt.step()
        mb_scheduler.step()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/loss_G', loss_G.item(), global_step=step)
            writer.add_scalar('train/loss_D', loss_D.item(), global_step=step)
            writer.add_scalar('train/mel_err', mel_err, global_step=step)
            if step % log_heavy_interval == 0:
                target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
                pred_mel = imshow_to_buf(mel_func(pred[:1, 0])[0].detach().cpu().numpy())
                writer.add_image('train/target_mel', target_mel, global_step=step)
                writer.add_image('train/pred_mel', pred_mel, global_step=step)
                writer.add_audio('train/pred_audio', pred_audio,
                                 sample_rate=settings.SAMPLE_RATE, global_step=step)
                writer.add_audio('train/target_audio', target_audio,
                                 sample_rate=settings.SAMPLE_RATE, global_step=step)
            # console
            msg = f'train: step: {step} / loss_G: {loss_G.item()} / loss_D: {loss_D.item()} / ' \
                  f' mel_err: {mel_err}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_g_loss, valid_d_loss, valid_mel_loss = 0., 0., 0.
            count = 0
            mb_generator.eval()
            discriminator.eval()
            for batch_idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)
                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

                    # generator forward / synthesis
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands, target_subbands)
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)

                    # mel error
                    pred_mel = mel_func(pred.squeeze(1).detach())
                    mel_err = F.l1_loss(mel, pred_mel).item()

                    #
                    # discriminator part
                    #
                    d_fake_det = discriminator(pred.detach())
                    d_real = discriminator(wav.unsqueeze(1))

                    # BUGFIX: these inner loops previously reused `idx`, the
                    # enumerate batch index, corrupting `count` below.
                    loss_D = 0
                    for layer_idx in range(dis_block_layers - 1, len(d_fake_det),
                                           dis_block_layers):
                        loss_D += torch.mean((d_fake_det[layer_idx] - 1) ** 2)
                    for layer_idx in range(dis_block_layers - 1, len(d_real),
                                           dis_block_layers):
                        loss_D += torch.mean(d_real[layer_idx] ** 2)

                    #
                    # generator part
                    #
                    d_fake = discriminator(pred)
                    loss_G = 0
                    for layer_idx in range(dis_block_layers - 1, len(d_fake),
                                           dis_block_layers):
                        loss_G += ((d_fake[layer_idx] - 1) ** 2).mean()
                    loss_G *= lambda_gen

                    # stft loss
                    stft_loss = get_stft_loss(pred, wav, pred_subbands,
                                              target_subbands, stft_funcs_for_loss)[0]
                    loss_G += stft_loss

                    valid_d_loss += loss_D.item()
                    valid_g_loss += loss_G.item()
                    valid_mel_loss += mel_err
                count = batch_idx
            valid_d_loss /= (count + 1)
            valid_g_loss /= (count + 1)
            valid_mel_loss /= (count + 1)
            mb_generator.train()
            discriminator.train()

            # log validation
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
            pred_mel = imshow_to_buf(mel_func(pred[:1, 0])[0].detach().cpu().numpy())
            writer.add_image('valid/target_mel', target_mel, global_step=step)
            writer.add_image('valid/pred_mel', pred_mel, global_step=step)
            writer.add_scalar('valid/loss_G', valid_g_loss, global_step=step)
            writer.add_scalar('valid/loss_D', valid_d_loss, global_step=step)
            writer.add_scalar('valid/mel_err', valid_mel_loss, global_step=step)
            writer.add_audio('valid/pred_audio', pred_audio,
                             sample_rate=settings.SAMPLE_RATE, global_step=step)
            writer.add_audio('valid/target_audio', target_audio,
                             sample_rate=settings.SAMPLE_RATE, global_step=step)
            # console
            log(f'---- loss_G: {valid_g_loss} / loss_D: {valid_d_loss} '
                f'/ mel loss : {valid_mel_loss} ----')

            #
            # save checkpoint
            #
            is_best = valid_g_loss < best_loss
            if is_best:
                best_loss = valid_g_loss
            save_checkpoint(mb_generator, discriminator, mb_opt, dis_opt,
                            mb_scheduler, dis_scheduler, model_dir, step,
                            valid_g_loss, is_best=is_best)

    log('----- Finish ! -----')
def _main(
        meta_dir: str,
        save_prefix: str = '',
        model_name: str = 'refine_unet_base',  # or refine_spectrogram_unet
        save_dir: str = 'savedir',
        batch_size: int = 128,
        num_workers: int = 16,
        fix_len: float = 2.,
        lr: float = 5e-4,
        beta1: float = 0.5,
        beta2: float = 0.9,
        weight_decay: float = 0.0,
        max_step: int = 100000,
        valid_max_step: int = 30,
        save_interval: int = 1000,
        log_interval: int = 100,
        grad_clip: float = 0.0,
        grad_norm: float = 30.0,
        milestones: Tuple[int] = None,
        gamma: float = 0.2,
        is_augment: bool = True,
        is_dsd: bool = False,
        # model args
        hidden_dim: int = 768,
        filter_len: int = 512,
        hop_len: int = 64,
        block_layers: int = 4,
        layers: int = 4,
        kernel_size: int = 3,
        norm: str = 'ins',
        act: str = 'comp',
        refine_layers: int = 1,
):
    """Train a single wave-to-wave configuration and return its final loss.

    Intended as a hyper-parameter-search objective: model hyper-parameters are
    explicit arguments, and the result is a dict with the trainer's loss.
    """
    betas = beta1, beta2
    # setup model args forwarded into the model constructor
    model_args = {
        'hidden_dim': hidden_dim,
        'filter_len': filter_len,
        'hop_len': hop_len,
        'spec_dim': filter_len // 2 + 1,
        'block_layers': block_layers,
        'layers': layers,
        'kernel_size': kernel_size,
        'norm': norm,
        'refine_layers': refine_layers,
        'act': act
    }
    # create model
    model = build_model(model_name, extra_kwargs=model_args).cuda()
    # multi-gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    # create optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas,
                                 weight_decay=weight_decay)
    if milestones:
        milestones = [int(x) for x in list(milestones)]
        scheduler = MultiStepLR(optimizer, milestones, gamma=gamma)
    else:
        scheduler = None
    # adopt dsd100 case
    if is_dsd:
        sr = 44100
        if is_augment:
            dataset_func = get_datasets
            meta_cls = DSD100Meta
        else:
            # NOTE(review): `meta_cls` is not assigned in this branch but is
            # passed to `dataset_func` below -> NameError when
            # is_dsd=True, is_augment=False. Verify the intended meta class.
            dataset_func = dsd100.get_datasets
    else:
        sr = 22050
        # load dataset
        if is_augment:
            dataset_func = get_datasets
            meta_cls = VoiceBankMeta
        else:
            # NOTE(review): same issue here — `meta_cls` is unset when
            # is_dsd=False, is_augment=False; only the augmented paths are
            # safe as written.
            dataset_func = voice_bank.get_datasets
    # fix_len is given in seconds; convert to samples at the chosen rate.
    train_loader, valid_loader = dataset_func(meta_dir,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              meta_cls=meta_cls,
                                              fix_len=int(fix_len * sr),
                                              audio_mask=True)
    # train (no pretrained checkpoint is used in the search setting)
    loss = Wave2WaveTrainer(model, optimizer, train_loader, valid_loader,
                            max_step=max_step,
                            valid_max_step=min(valid_max_step, len(valid_loader)),
                            save_interval=save_interval,
                            log_interval=log_interval,
                            save_dir=save_dir,
                            save_prefix=save_prefix,
                            grad_clip=grad_clip,
                            grad_norm=grad_norm,
                            pretrained_path='',
                            scheduler=scheduler,
                            sr=sr).run()
    return {
        'loss': loss,
        'status': 'ok',
    }
def main(train_path: str, preprocessed_path: str, save_dir: str, save_prefix: str,
         model_name: str, pretrained_path: str = '', num_workers: int = 16,
         batch_size: int = 16, pitch_feature: str = 'phoneme',
         energy_feature: str = 'phoneme', pitch_min: float = 0.,
         energy_min: float = 0., lr: float = 2e-4, weight_decay: float = 0.0001,
         betas=(0.9, 0.98), max_step: int = 400000, group_size: int = 4,
         save_interval: int = 10000, log_interval: int = 50, grad_clip: float = 0.0,
         grad_norm: float = 5.0, milestones: Tuple[int] = None, gamma: float = 0.2,
         sr: int = 22050, seed: int = 2021, is_reference: bool = False):
    """Entry point wiring model, optimizer, dataset and trainer for TTS training."""
    # Network on GPU; data-parallel when several devices are visible.
    net = build_model(model_name).cuda()
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    # AdamW optimizer with an optional multi-step LR schedule.
    opt = torch.optim.AdamW(net.parameters(), lr=lr, betas=betas,
                            weight_decay=weight_decay)
    if milestones:
        lr_sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones, gamma=gamma)
    else:
        lr_sched = None

    # Dataset yields sorted, fixed-size batches; the loader groups them.
    dataset = Dataset(train_path, preprocessed_path,
                      pitch_min=pitch_min,
                      energy_min=energy_min,
                      text_cleaners=['english_cleaners'],
                      batch_size=batch_size,
                      sort=True,
                      drop_last=True,
                      is_reference=is_reference)
    train_loader = DataLoader(dataset,
                              batch_size=batch_size * group_size,
                              shuffle=True,
                              collate_fn=dataset.collate_fn,
                              num_workers=num_workers)

    # No validation loader is used; valid_max_step is fixed to 1.
    BaseTrainer(net, opt, train_loader, None,
                max_step=max_step,
                valid_max_step=1,
                save_interval=save_interval,
                log_interval=log_interval,
                pitch_feature=pitch_feature,
                energy_feature=energy_feature,
                save_dir=save_dir,
                save_prefix=save_prefix,
                grad_clip=grad_clip,
                grad_norm=grad_norm,
                pretrained_path=pretrained_path,
                sr=sr,
                scheduler=lr_sched,
                seed=seed,
                is_reference=is_reference).run()
def main(meta_dir: str, save_dir: str, save_prefix: str = '', pretrained_path: str = '',
         model_name: str = 'refine_unet_base', batch_size: int = 128, num_workers: int = 16,
         fix_len: float = 2., lr: float = 3e-4, betas: Tuple[float] = (0.5, 0.9),
         weight_decay: float = 0.0, max_step: int = 100000, valid_max_step: int = 30,
         save_interval: int = 1000, log_interval: int = 50, grad_clip: float = 0.0,
         grad_norm: float = 30.0, is_audioset: bool = True, milestones: Tuple[int] = None,
         gamma: float = 0.1):
    """Train a wave-to-wave model on either the audioset-style or voice-bank dataset."""
    # The metadata directory must already exist.
    assert os.path.exists(meta_dir)

    # Network on GPU; data-parallel across all visible devices.
    net = build_model(model_name).cuda()
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    # Adam optimizer with an optional multi-step LR schedule.
    opt = torch.optim.Adam(net.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
    if milestones:
        lr_sched = MultiStepLR(opt, [int(m) for m in list(milestones)], gamma=gamma)
    else:
        lr_sched = None

    # Choose the dataset factory, then build loaders.
    loader_factory = get_datasets if is_audioset else voice_bank.get_datasets
    train_loader, valid_loader = loader_factory(
        meta_dir, batch_size=batch_size, num_workers=num_workers,
        fix_len=int(fix_len * settings.SAMPLE_RATE), audio_mask=True)

    # train
    Wave2WaveTrainer(net, opt, train_loader, valid_loader,
                     max_step=max_step,
                     valid_max_step=min(valid_max_step, len(valid_loader)),
                     save_interval=save_interval,
                     log_interval=log_interval,
                     save_dir=save_dir,
                     save_prefix=save_prefix,
                     grad_clip=grad_clip,
                     grad_norm=grad_norm,
                     pretrained_path=pretrained_path,
                     scheduler=lr_sched).run()