Example #1
def main():
    train_loader, val_loader, collate_fn = prepare_dataloaders(hparams)
    model = nn.DataParallel(GraphTTS(hparams)).cuda()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hparams.lr,
                                 betas=(0.9, 0.98),
                                 eps=1e-09)
    criterion = TransformerLoss()
    writer = get_writer(hparams.output_directory, hparams.log_directory)

    iteration, loss = 0, 0
    model.train()
    print("Training Start!!!")
    while iteration < (hparams.train_steps * hparams.accumulation):
        for i, batch in enumerate(train_loader):
            text_padded, adj_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [
                reorder_batch(x, hparams.n_gpus).cuda() for x in batch
            ]

            mel_loss, bce_loss, guide_loss = model(text_padded, adj_padded,
                                                   mel_padded, gate_padded,
                                                   text_lengths, mel_lengths,
                                                   criterion)

            mel_loss, bce_loss, guide_loss = [
                torch.mean(x) for x in [mel_loss, bce_loss, guide_loss]
            ]
            sub_loss = (mel_loss + bce_loss +
                        guide_loss) / hparams.accumulation
            sub_loss.backward()
            loss = loss + sub_loss.item()

            iteration += 1
            if iteration % hparams.accumulation == 0:
                lr_scheduling(optimizer, iteration // hparams.accumulation)
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hparams.grad_clip_thresh)
                optimizer.step()
                model.zero_grad()
                writer.add_losses(mel_loss.item(), bce_loss.item(),
                                  guide_loss.item(),
                                  iteration // hparams.accumulation, 'Train')
                loss = 0

            if iteration % (hparams.iters_per_validation *
                            hparams.accumulation) == 0:
                validate(model, criterion, val_loader, iteration, writer)

            if iteration % (hparams.iters_per_checkpoint *
                            hparams.accumulation) == 0:
                save_checkpoint(
                    model,
                    optimizer,
                    hparams.lr,
                    iteration // hparams.accumulation,
                    filepath=
                    f'{hparams.output_directory}/{hparams.log_directory}')

            if iteration == (hparams.train_steps * hparams.accumulation):
                break
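
All four loops call lr_scheduling(optimizer, step) inside the accumulation block (Example #4 has the call commented out), but the helper itself is not shown. In Transformer-style TTS training this is usually a Noam warm-up schedule; a minimal sketch under that assumption, with init_lr and warmup_steps as placeholder values rather than the repo's actual hparams:

def lr_scheduling(optimizer, step, init_lr=1e-3, warmup_steps=4000):
    # Noam-style schedule (assumed): linear warm-up followed by inverse
    # square-root decay, reaching init_lr at step == warmup_steps.
    step = max(step, 1)
    lr = init_lr * warmup_steps ** 0.5 * min(step ** -0.5,
                                             step * warmup_steps ** -1.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr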
Example #2
def main():
    train_loader, val_loader, collate_fn = prepare_dataloaders(hparams)
    model = nn.DataParallel(Model(hparams)).cuda()

    # Warm-start the embedding/encoder weights (and the alpha scaling factors)
    # from a pre-trained teacher checkpoint.
    if hparams.pretrained_embedding:
        state_dict = torch.load(
            f'{hparams.teacher_dir}/checkpoint_200000')['state_dict']
        for k, v in state_dict.items():
            if k == 'alpha1':
                model.alpha1.data = v

            if k == 'alpha2':
                model.alpha2.data = v

            if 'Embedding' in k:
                setattr(model, k, v)

            if 'Encoder' in k:
                setattr(model, k, v)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hparams.lr,
                                 betas=(0.9, 0.98),
                                 eps=1e-09)
    criterion = TransformerLoss()
    writer = get_writer(hparams.output_directory, hparams.log_directory)

    iteration, loss = 0, 0
    model.train()
    print("Training Start!!!")
    while iteration < (hparams.train_steps * hparams.accumulation):
        for i, batch in enumerate(train_loader):
            text_padded, text_lengths, mel_padded, mel_lengths, align_padded = [
                reorder_batch(x, hparams.n_gpus).cuda() for x in batch
            ]
            mel_loss, duration_loss = model(text_padded, mel_padded,
                                            align_padded, text_lengths,
                                            mel_lengths, criterion)

            mel_loss, duration_loss = [
                torch.mean(x) for x in [mel_loss, duration_loss]
            ]
            sub_loss = (mel_loss + duration_loss) / hparams.accumulation
            sub_loss.backward()
            loss = loss + sub_loss.item()

            iteration += 1
            if iteration % hparams.accumulation == 0:
                lr_scheduling(optimizer, iteration // hparams.accumulation)
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               hparams.grad_clip_thresh)
                optimizer.step()
                model.zero_grad()
                writer.add_scalar('mel_loss',
                                  mel_loss.item(),
                                  global_step=iteration //
                                  hparams.accumulation)
                writer.add_scalar('duration_loss',
                                  duration_loss.item(),
                                  global_step=iteration //
                                  hparams.accumulation)
                loss = 0

            if iteration % (hparams.iters_per_validation *
                            hparams.accumulation) == 0:
                validate(model, criterion, val_loader, iteration, writer)

            if iteration % (hparams.iters_per_checkpoint *
                            hparams.accumulation) == 0:
                save_checkpoint(
                    model,
                    optimizer,
                    hparams.lr,
                    iteration // hparams.accumulation,
                    filepath=
                    f'{hparams.output_directory}/{hparams.log_directory}')

            if iteration == (hparams.train_steps * hparams.accumulation):
                break
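
Every loop above runs reorder_batch(x, hparams.n_gpus) on each padded tensor before the forward pass. nn.DataParallel scatters a batch into contiguous chunks, so if the collate function sorts samples by length, one GPU would receive all the longest sequences; reordering spreads the lengths across the chunks. A minimal sketch of such a helper (an assumption, not necessarily the repo's exact implementation):

def reorder_batch(x, n_gpus):
    # Interleave samples so that each contiguous chunk handed to a GPU by
    # nn.DataParallel contains a similar mix of sequence lengths.
    assert x.size(0) % n_gpus == 0, 'batch size must be divisible by n_gpus'
    chunk = x.size(0) // n_gpus
    new_x = x.new_empty(x.size())
    for g in range(n_gpus):
        new_x[g * chunk:(g + 1) * chunk] = x[g::n_gpus]
    return new_x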
Example #3
def main(args):
    train_loader, val_loader, collate_fn = prepare_dataloaders(hparams, stage=args.stage)

    if args.stage != 0:
        checkpoint_path = f"training_log/aligntts/stage{args.stage-1}/checkpoint_{hparams.train_steps[args.stage-1]}"
        # Strip the 'module.' prefix that nn.DataParallel prepends to parameter names.
        state_dict = {}
        for k, v in torch.load(checkpoint_path)['state_dict'].items():
            state_dict[k[7:]] = v

        model = Model(hparams).cuda()
        model.load_state_dict(state_dict)
        model = nn.DataParallel(model).cuda()
    else:
        model = nn.DataParallel(Model(hparams)).cuda()

    criterion = MDNLoss()
    writer = get_writer(hparams.output_directory, f'{hparams.log_directory}/stage{args.stage}')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hparams.lr,
                                 betas=(0.9, 0.98),
                                 eps=1e-09)
    iteration, loss = 0, 0
    model.train()

    print(f'Stage{args.stage} Start!!! ({str(datetime.now())})')
    while True:
        for i, batch in enumerate(train_loader):
            if args.stage == 0:
                text_padded, mel_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]
                align_padded = None
            else:
                text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]

            sub_loss = model(text_padded,
                             mel_padded,
                             align_padded,
                             text_lengths,
                             mel_lengths,
                             criterion,
                             stage=args.stage)
            sub_loss = sub_loss.mean() / hparams.accumulation
            sub_loss.backward()
            loss = loss + sub_loss.item()
            iteration += 1

            if iteration % hparams.accumulation == 0:
                lr_scheduling(optimizer, iteration // hparams.accumulation)
                nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
                optimizer.step()
                model.zero_grad()
                writer.add_scalar('Train loss', loss, iteration // hparams.accumulation)
                loss = 0

            if iteration % (hparams.iters_per_validation * hparams.accumulation) == 0:
                validate(model, criterion, val_loader, iteration, writer, args.stage)

            if iteration % (hparams.iters_per_checkpoint * hparams.accumulation) == 0:
                save_checkpoint(model,
                                optimizer,
                                hparams.lr,
                                iteration // hparams.accumulation,
                                filepath=f'{hparams.output_directory}/{hparams.log_directory}/stage{args.stage}')

            if iteration == (hparams.train_steps[args.stage] * hparams.accumulation):
                break

        if iteration == (hparams.train_steps[args.stage] * hparams.accumulation):
            break
            
    print(f'Stage{args.stage} End!!! ({str(datetime.now())})')
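
Examples #3 and #4 read checkpoints back as torch.load(path)['state_dict'] (and Example #4 also reads ['iteration']) from files named checkpoint_{step}, so save_checkpoint presumably writes a dict with at least those keys into the run directory. A minimal sketch under that assumption; the repo's own helper may store more:

import os
import torch

def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    # Store everything the loading code above expects: 'state_dict' (still
    # carrying the DataParallel 'module.' prefix) and 'iteration'.
    os.makedirs(filepath, exist_ok=True)
    torch.save({'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate,
                'iteration': iteration},
               f'{filepath}/checkpoint_{iteration}')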
Example #4
def main(args):
    train_loader, val_loader, collate_fn = prepare_dataloaders(
        hparams, stage=args.stage)
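    # Resume the iteration counter when warm-starting stage 0 from a
    # pre-trained checkpoint (see the stage == 0 branch below).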
    initial_iteration = None
    if args.stage != 0 and args.pre_trained_model != '':
        checkpoint_path = f"training_log/aligntts/stage{args.stage-1}/checkpoint_{hparams.train_steps[args.stage-1]}"

        if not os.path.isfile(checkpoint_path):
            print(f'{checkpoint_path} does not exist')
            checkpoint_path = sorted(
                glob(f"training_log/aligntts/stage{args.stage-1}/checkpoint_*")
            )[-1]
            print(f'Loading {checkpoint_path} instead')

        state_dict = {}
        for k, v in torch.load(checkpoint_path)['state_dict'].items():
            state_dict[k[7:]] = v

        model = Model(hparams).cuda()
        model.load_state_dict(state_dict)
        model = nn.DataParallel(model).cuda()
    elif args.stage != 0:
        model = nn.DataParallel(Model(hparams)).cuda()
    else:
        if args.pre_trained_model != '':
            if not os.path.isfile(args.pre_trained_model):
                print(f'{args.pre_trained_model} does not exist')

            state_dict = {}
            for k, v in torch.load(
                    args.pre_trained_model)['state_dict'].items():
                state_dict[k[7:]] = v
            initial_iteration = torch.load(args.pre_trained_model)['iteration']
            model = Model(hparams).cuda()
            model.load_state_dict(state_dict)
            model = nn.DataParallel(model).cuda()
        else:
            model = nn.DataParallel(Model(hparams)).cuda()

    criterion = MDNDNNLoss()
    writer = get_writer(hparams.output_directory,
                        f'{hparams.log_directory}/stage{args.stage}')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hparams.lr,
                                 betas=(0.9, 0.98),
                                 eps=1e-09)
    iteration, loss = 0, 0
    if initial_iteration is not None:
        iteration = initial_iteration

    model.train()

    print(f'Stage{args.stage} Start!!! ({str(datetime.now())})')
    while True:
        for i, batch in enumerate(train_loader):
            if args.stage == 0:
                text_padded, mel_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]
                align_padded = None
            else:
                text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]

            sub_loss = model(text_padded,
                             mel_padded,
                             align_padded,
                             text_lengths,
                             mel_lengths,
                             criterion,
                             stage=args.stage,
                             log_viterbi=args.log_viterbi,
                             cpu_viterbi=args.cpu_viterbi)
            sub_loss = sub_loss.mean() / hparams.accumulation
            sub_loss.backward()
            loss = loss + sub_loss.item()
            iteration += 1
            if iteration % 100 == 0:
                print(
                    f'[{str(datetime.now())}] Stage {args.stage} Iter {iteration:<6d} Loss {loss:<8.6f}'
                )

            if iteration % hparams.accumulation == 0:
                # lr_scheduling(optimizer, iteration//hparams.accumulation)
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hparams.grad_clip_thresh)
                optimizer.step()
                model.zero_grad()
                writer.add_scalar('Train loss', loss,
                                  iteration // hparams.accumulation)
                writer.add_scalar('Learning rate', get_lr(optimizer),
                                  iteration // hparams.accumulation)
                loss = 0

            # validate(model, criterion, val_loader, iteration, writer, args.stage)
            if iteration % (hparams.iters_per_validation *
                            hparams.accumulation) == 0:
                validate(model, criterion, val_loader, iteration, writer,
                         args.stage)

            if iteration % (hparams.iters_per_checkpoint *
                            hparams.accumulation) == 0:
                save_checkpoint(
                    model,
                    optimizer,
                    hparams.lr,
                    iteration // hparams.accumulation,
                    filepath=
                    f'{hparams.output_directory}/{hparams.log_directory}/stage{args.stage}'
                )

            if iteration == (hparams.train_steps[args.stage] *
                             hparams.accumulation):
                break

        if iteration == (hparams.train_steps[args.stage] *
                         hparams.accumulation):
            break

    print(f'Stage{args.stage} End!!! ({str(datetime.now())})')
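
Two more helpers that every example assumes are get_writer and, in Example #4, get_lr. A minimal sketch of both, under the assumption that the writer is (or wraps) a TensorBoard SummaryWriter; Example #1's writer.add_losses() suggests the real helper returns a subclass with that convenience method:

import os
from torch.utils.tensorboard import SummaryWriter

def get_writer(output_directory, log_directory):
    # A plain TensorBoard writer rooted at the run directory (assumed).
    logging_path = f'{output_directory}/{log_directory}'
    os.makedirs(logging_path, exist_ok=True)
    return SummaryWriter(logging_path)

def get_lr(optimizer):
    # Current learning rate of the first parameter group, as logged under
    # 'Learning rate' in Example #4.
    return optimizer.param_groups[0]['lr']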