def main(args):
    train_loader, val_loader, collate_fn = prepare_dataloaders(hp)
    model = Model(hp).cuda()
    optimizer = torch.optim.Adamax(model.parameters(), lr=hp.lr)
    writer = get_writer(hp.output_directory, args.logdir)

    # Wrap model and optimizer for mixed-precision training (NVIDIA apex, O1).
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    iteration = 0
    model.train()
    print(f"Training Start!!! ({args.logdir})")
    while iteration < hp.train_steps:
        for i, batch in enumerate(train_loader):
            text_padded, text_lengths, mel_padded, mel_lengths = [
                x.cuda() for x in batch
            ]
            recon_loss, kl_loss, duration_loss, align_loss = model(
                text_padded, mel_padded, text_lengths, mel_lengths)

            # Anneal the KL term linearly over the first kl_warmup_steps iterations.
            alpha = min(1, iteration / hp.kl_warmup_steps)
            with amp.scale_loss(
                    recon_loss + alpha * kl_loss + duration_loss + align_loss,
                    optimizer) as scaled_loss:
                scaled_loss.backward()

            iteration += 1
            lr_scheduling(optimizer, iteration)
            nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)
            optimizer.step()
            model.zero_grad()

            writer.add_scalar('train_recon_loss', recon_loss, global_step=iteration)
            writer.add_scalar('train_kl_loss', kl_loss, global_step=iteration)
            writer.add_scalar('train_duration_loss', duration_loss, global_step=iteration)
            writer.add_scalar('train_align_loss', align_loss, global_step=iteration)

            if iteration % hp.iters_per_validation == 0:
                validate(model, val_loader, iteration, writer)

            if iteration % hp.iters_per_checkpoint == 0:
                save_checkpoint(model,
                                optimizer,
                                hp.lr,
                                iteration,
                                filepath=f'{hp.output_directory}/{args.logdir}')

            if iteration == hp.train_steps:
                break
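# --------------------------------------------------------------------------
# The loop above depends on NVIDIA apex for loss scaling. As a point of
# reference only, a minimal sketch of the same inner step with PyTorch's
# built-in torch.cuda.amp (PyTorch >= 1.6) is shown below. The hp.* fields,
# model interface, and loader come from the loop above; this is a hedged
# sketch, not the author's code.
# --------------------------------------------------------------------------
scaler = torch.cuda.amp.GradScaler()

for iteration, batch in enumerate(train_loader, start=1):
    text_padded, text_lengths, mel_padded, mel_lengths = [x.cuda() for x in batch]

    with torch.cuda.amp.autocast():
        recon_loss, kl_loss, duration_loss, align_loss = model(
            text_padded, mel_padded, text_lengths, mel_lengths)
        alpha = min(1, iteration / hp.kl_warmup_steps)   # same KL annealing as above
        total_loss = recon_loss + alpha * kl_loss + duration_loss + align_loss

    scaler.scale(total_loss).backward()
    scaler.unscale_(optimizer)                           # unscale before clipping
    nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()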
def main(args):
    train_loader, val_loader, collate_fn = prepare_dataloaders(hparams, stage=args.stage)

    if args.stage != 0:
        # Warm-start from the final checkpoint of the previous stage,
        # stripping the 'module.' prefix added by nn.DataParallel.
        checkpoint_path = f"training_log/aligntts/stage{args.stage-1}/checkpoint_{hparams.train_steps[args.stage-1]}"
        state_dict = {}
        for k, v in torch.load(checkpoint_path)['state_dict'].items():
            state_dict[k[7:]] = v

        model = Model(hparams).cuda()
        model.load_state_dict(state_dict)
        model = nn.DataParallel(model).cuda()
    else:
        model = nn.DataParallel(Model(hparams)).cuda()

    criterion = MDNLoss()
    writer = get_writer(hparams.output_directory,
                        f'{hparams.log_directory}/stage{args.stage}')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hparams.lr,
                                 betas=(0.9, 0.98),
                                 eps=1e-09)
    iteration, loss = 0, 0
    model.train()
    print(f'Stage{args.stage} Start!!! ({str(datetime.now())})')

    while True:
        for i, batch in enumerate(train_loader):
            if args.stage == 0:
                text_padded, mel_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]
                align_padded = None
            else:
                text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]

            sub_loss = model(text_padded, mel_padded, align_padded,
                             text_lengths, mel_lengths, criterion,
                             stage=args.stage)
            # Average over GPUs and scale for gradient accumulation.
            sub_loss = sub_loss.mean() / hparams.accumulation
            sub_loss.backward()
            loss = loss + sub_loss.item()
            iteration += 1

            if iteration % hparams.accumulation == 0:
                lr_scheduling(optimizer, iteration // hparams.accumulation)
                nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
                optimizer.step()
                model.zero_grad()
                writer.add_scalar('Train loss', loss, iteration // hparams.accumulation)
                loss = 0

            if iteration % (hparams.iters_per_validation * hparams.accumulation) == 0:
                validate(model, criterion, val_loader, iteration, writer, args.stage)

            if iteration % (hparams.iters_per_checkpoint * hparams.accumulation) == 0:
                save_checkpoint(model,
                                optimizer,
                                hparams.lr,
                                iteration // hparams.accumulation,
                                filepath=f'{hparams.output_directory}/{hparams.log_directory}/stage{args.stage}')

            if iteration == (hparams.train_steps[args.stage] * hparams.accumulation):
                break

        if iteration == (hparams.train_steps[args.stage] * hparams.accumulation):
            break

    print(f'Stage{args.stage} End!!! ({str(datetime.now())})')
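# --------------------------------------------------------------------------
# reorder_batch is called above but not defined here. Presumably it
# rearranges a padded batch so that nn.DataParallel hands each of the
# hparams.n_gpus replicas an equal, similarly sized chunk. The function
# below is a hypothetical sketch of such a helper, not the repository's
# actual implementation.
# --------------------------------------------------------------------------
def reorder_batch_sketch(x, n_gpus):
    # Assumes the batch dimension is divisible by the number of GPUs.
    assert x.size(0) % n_gpus == 0, "batch size must be divisible by n_gpus"
    chunk = x.size(0) // n_gpus
    new_x = x.new_zeros(x.size())
    for i in range(n_gpus):
        # Place every n_gpus-th sample into the i-th contiguous chunk so each
        # DataParallel replica sees a comparable mix of sequence lengths.
        new_x[i * chunk:(i + 1) * chunk] = x[i::n_gpus]
    return new_x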
def main(args):
    train_loader, val_loader, collate_fn = prepare_dataloaders(hparams, stage=args.stage)
    initial_iteration = None

    if args.stage != 0:
        # Resume from the final checkpoint of the previous stage; if it is
        # missing, fall back to the latest checkpoint that exists.
        checkpoint_path = f"training_log/aligntts/stage{args.stage-1}/checkpoint_{hparams.train_steps[args.stage-1]}"
        if not os.path.isfile(checkpoint_path):
            print(f'{checkpoint_path} does not exist')
            checkpoint_path = sorted(
                glob(f"training_log/aligntts/stage{args.stage-1}/checkpoint_*"))[-1]
            print(f'Loading {checkpoint_path} instead')

        # Strip the 'module.' prefix added by nn.DataParallel.
        state_dict = {}
        for k, v in torch.load(checkpoint_path)['state_dict'].items():
            state_dict[k[7:]] = v

        model = Model(hparams).cuda()
        model.load_state_dict(state_dict)
        model = nn.DataParallel(model).cuda()
    else:
        if args.pre_trained_model != '':
            if not os.path.isfile(args.pre_trained_model):
                print(f'{args.pre_trained_model} does not exist')

            # Load the pre-trained checkpoint once and reuse it for both the
            # weights and the starting iteration.
            checkpoint = torch.load(args.pre_trained_model)
            state_dict = {}
            for k, v in checkpoint['state_dict'].items():
                state_dict[k[7:]] = v
            initial_iteration = checkpoint['iteration']

            model = Model(hparams).cuda()
            model.load_state_dict(state_dict)
            model = nn.DataParallel(model).cuda()
        else:
            model = nn.DataParallel(Model(hparams)).cuda()

    criterion = MDNLoss()
    writer = get_writer(hparams.output_directory,
                        f'{hparams.log_directory}/stage{args.stage}')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hparams.lr,
                                 betas=(0.9, 0.98),
                                 eps=1e-09)

    iteration, loss = 0, 0
    if initial_iteration is not None:
        iteration = initial_iteration
    model.train()
    print(f'Stage{args.stage} Start!!! ({str(datetime.now())})')

    while True:
        for i, batch in enumerate(train_loader):
            if args.stage == 0:
                text_padded, mel_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]
                align_padded = None
            else:
                text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]

            sub_loss = model(text_padded, mel_padded, align_padded,
                             text_lengths, mel_lengths, criterion,
                             stage=args.stage,
                             log_viterbi=args.log_viterbi,
                             cpu_viterbi=args.cpu_viterbi)
            # Average over GPUs and scale for gradient accumulation.
            sub_loss = sub_loss.mean() / hparams.accumulation
            sub_loss.backward()
            loss = loss + sub_loss.item()
            iteration += 1

            if iteration % 100 == 0:
                print(f'[{str(datetime.now())}] Stage {args.stage} '
                      f'Iter {iteration:<6d} Loss {loss:<8.6f}')

            if iteration % hparams.accumulation == 0:
                lr_scheduling(optimizer, iteration // hparams.accumulation)
                nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
                optimizer.step()
                model.zero_grad()
                writer.add_scalar('Train loss', loss, iteration // hparams.accumulation)
                writer.add_scalar('Learning rate', get_lr(optimizer),
                                  iteration // hparams.accumulation)
                loss = 0

            if iteration % (hparams.iters_per_validation * hparams.accumulation) == 0:
                validate(model, criterion, val_loader, iteration, writer, args.stage)

            if iteration % (hparams.iters_per_checkpoint * hparams.accumulation) == 0:
                save_checkpoint(model,
                                optimizer,
                                hparams.lr,
                                iteration // hparams.accumulation,
                                filepath=f'{hparams.output_directory}/{hparams.log_directory}/stage{args.stage}')

            if iteration == (hparams.train_steps[args.stage] * hparams.accumulation):
                break

        if iteration == (hparams.train_steps[args.stage] * hparams.accumulation):
            break

    print(f'Stage{args.stage} End!!! ({str(datetime.now())})')
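# --------------------------------------------------------------------------
# lr_scheduling and get_lr are called in the loops above but not defined
# here. The Adam settings (betas=(0.9, 0.98), eps=1e-9) match the Transformer
# recipe, so lr_scheduling is presumably a Noam-style warmup. The helpers
# below are a hypothetical sketch under that assumption; warmup_steps and
# hidden_dim are placeholder values, not the repository's.
# --------------------------------------------------------------------------
def lr_scheduling(optimizer, step, warmup_steps=4000, hidden_dim=256):
    # Noam schedule: linear warmup for warmup_steps, then inverse-sqrt decay.
    lr = hidden_dim ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def get_lr(optimizer):
    # All parameter groups share one learning rate in this setup.
    return optimizer.param_groups[0]['lr']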