def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr,
                   total_steps):

    for p in optimiser.param_groups:
        p['lr'] = lr

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1

    for e in range(1, epochs + 1):

        start = time.time()
        running_loss = 0.

        for i, (x, y, m, spk_embd) in enumerate(train_set, 1):

            x, m, y, spk_embd = x.cuda(), m.cuda(), y.cuda(), spk_embd.cuda()

            y_hat = model(x, m, spk_embd)

            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)

            elif model.mode == 'MOL':
                y = y.float()

            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
            running_loss += loss.item()

            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if step % hp.voc_checkpoint_every == 0:
                gen_testset(model, test_set, hp.voc_gen_at_checkpoint,
                            hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                            paths.voc_output)
                model.checkpoint(paths.voc_checkpoints)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        model.save(paths.voc_latest_weights)
        model.log(paths.voc_log, msg)
        print(' ')
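All of these examples report progress through a `stream` helper that is not shown in the snippets. A minimal sketch, assuming it simply rewrites the current terminal line:

import sys

def stream(message):
    # overwrite the current terminal line so the progress readout
    # updates in place instead of scrolling
    sys.stdout.write(f'\r{message}')
    sys.stdout.flush()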
Example #2
def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer,
                   train_set, test_set, lr, total_steps):
    # Use same device as model parameters
    device = next(model.parameters()).device

    for g in optimizer.param_groups:
        g['lr'] = lr

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1

    for e in range(1, epochs + 1):

        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(train_set, 1):
            x, m, y = x.to(device), m.to(device), y.to(device)

            # Parallelize the model onto GPUs (workaround for a Python bug)
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                y_hat = data_parallel_workaround(model, x, m)
            else:
                y_hat = model(x, m)

            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)

            elif model.mode == 'MOL':
                y = y.float()

            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            if hp.voc_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.voc_clip_grad_norm)
                if torch.isnan(grad_norm):
                    print('grad_norm was NaN!')
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.voc_checkpoint_every == 0:
                gen_testset(model, test_set, hp.voc_gen_at_checkpoint,
                            hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                            paths.voc_output)
                ckpt_name = f'wave_step{k}K'
                save_checkpoint('voc',
                                paths,
                                model,
                                optimizer,
                                name=ckpt_name,
                                is_silent=True)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('voc', paths, model, optimizer, is_silent=True)
        model.log(paths.voc_log, msg)
        print(' ')
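The `data_parallel_workaround` call above is not defined in the snippet. A plausible sketch, assuming it drives the torch.nn.parallel primitives by hand instead of going through nn.DataParallel:

import torch
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def data_parallel_workaround(model, *inputs):
    # replicate the model onto every visible GPU, scatter the batch
    # across them, run the replicas in parallel, gather on GPU 0
    device_ids = list(range(torch.cuda.device_count()))
    replicas = replicate(model, device_ids)
    scattered = scatter(inputs, device_ids)
    replicas = replicas[:len(scattered)]
    outputs = parallel_apply(replicas, scattered)
    return gather(outputs, device_ids[0])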
Example #3
def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer,
                   train_set, test_set, init_lr, final_lr, total_steps):
    # Use same device as model parameters
    device = next(model.parameters()).device

    # lr is scheduled per epoch by adjust_learning_rate below, so it is
    # not fixed up front

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1

    for e in range(1, epochs + 1):

        adjust_learning_rate(optimizer, e, epochs, init_lr,
                             final_lr)  # initial and final learning rates - Begee
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(train_set, 1):
            # x/y: (Batch, sub_bands, T)
            x, m, y = x.to(device), m.to(device), y.to(device)

            #########################  MultiBand-WaveRNN   #########################
            if hp.voc_multiband:
                # y0/y1/y2/y3: (Batch, T, 1)
                # (the squeeze(0) in the original would wrongly drop a
                # batch dimension of size 1)
                y0 = y[:, 0, :].unsqueeze(-1)
                y1 = y[:, 1, :].unsqueeze(-1)
                y2 = y[:, 2, :].unsqueeze(-1)
                y3 = y[:, 3, :].unsqueeze(-1)

                y_hat = model(x, m)  # (Batch, T, num_classes, sub_bands)

                if model.mode == 'RAW':
                    # y_hat0..3: (Batch, num_classes, T, 1)
                    y_hat0 = y_hat[:, :, :, 0].transpose(1, 2).unsqueeze(-1)
                    y_hat1 = y_hat[:, :, :, 1].transpose(1, 2).unsqueeze(-1)
                    y_hat2 = y_hat[:, :, :, 2].transpose(1, 2).unsqueeze(-1)
                    y_hat3 = y_hat[:, :, :, 3].transpose(1, 2).unsqueeze(-1)

                elif model.mode == 'MOL':
                    y0 = y0.float()
                    y1 = y1.float()
                    y2 = y2.float()
                    y3 = y3.float()

                loss = (loss_func(y_hat0, y0) + loss_func(y_hat1, y1) +
                        loss_func(y_hat2, y2) + loss_func(y_hat3, y3))

            #########################  MultiBand-WaveRNN   #########################

            optimizer.zero_grad()
            loss.backward()

            if hp.voc_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.voc_clip_grad_norm).cpu()
                if np.isnan(grad_norm):
                    print('grad_norm was NaN!')
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.voc_checkpoint_every == 0:
                gen_testset(model, test_set, hp.voc_gen_at_checkpoint,
                            hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                            paths.voc_output)
                ckpt_name = f'wave_step{k}K'
                save_checkpoint('voc',
                                paths,
                                model,
                                optimizer,
                                name=ckpt_name,
                                is_silent=True)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('voc', paths, model, optimizer, is_silent=True)
        model.log(paths.voc_log, msg)
        print(' ')
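`adjust_learning_rate` is also external to these snippets. A minimal sketch, assuming a linear anneal from `init_lr` to `final_lr` over the epochs (the real schedule may decay differently, e.g. exponentially):

def adjust_learning_rate(optimizer, epoch, total_epochs, init_lr, final_lr):
    # linearly interpolate the learning rate over training and
    # apply it to every parameter group
    t = (epoch - 1) / max(total_epochs - 1, 1)
    lr = init_lr + t * (final_lr - init_lr)
    for g in optimizer.param_groups:
        g['lr'] = lr
    return lr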
Example #4
def voc_train_loop(model, loss_func, optimizer, train_set, test_set,
                   init_lr, final_lr, total_steps):

    total_iters = len(train_set)
    epochs = int((total_steps - model.get_step()) // total_iters + 1)

    if hp.amp:
        # NVIDIA apex automatic mixed precision
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    torch.backends.cudnn.benchmark = True

    for e in range(1, epochs + 1):

        adjust_learning_rate(optimizer, e, epochs, init_lr, final_lr)

        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(train_set, 1):
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            y_hat = model(x, m)

            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)

            elif model.mode == 'MOL':
                y = y.float()

            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)

            optimizer.zero_grad()

            if hp.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

            optimizer.step()
            running_loss += loss.item()

            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if step % hp.voc_checkpoint_every == 0:
                model.eval()
                gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                            hp.voc_target, hp.voc_overlap, paths.voc_output)
                model.checkpoint(paths.voc_checkpoints)
                model.train()

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        model.save(paths.voc_latest_weights)
        model.log(paths.voc_log, msg)
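The apex `amp` API used above is deprecated; the same mixed-precision update can be written with PyTorch's built-in AMP. A hedged equivalent of the inner step (`amp_train_step` is a hypothetical helper; the RAW/MOL reshaping is omitted for brevity):

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_train_step(model, loss_func, optimizer, x, m, y):
    # forward pass under autocast, backward through the loss scaler
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        y_hat = model(x, m)
        loss = loss_func(y_hat, y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # so clipping sees the true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()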
Example #5
def voc_train_loop(model, loss_func, optimiser, train_set, eval_set, test_set,
                   lr, total_steps, device, hp):

    for p in optimiser.param_groups:
        p['lr'] = lr

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1
    trg = None
    patience = hp.patience
    min_val_loss = np.inf

    for e in range(1, epochs + 1):

        start = time.time()
        running_loss = 0.
        running_pase_reg_loss = 0.
        running_nll_loss = 0.
        pase_reg_loss = None

        for i, (m, xm, x, y, neigh) in enumerate(train_set, 1):
            m, xm, x, y, neigh = m.to(device), xm.to(device), x.to(
                device), y.to(device), neigh.to(device)

            if hp.pase_cntnt is not None:
                if hp.pase_cntnt_ft:
                    m_clean = m
                    m = hp.pase_cntnt(xm.unsqueeze(1))
                    if hp.pase_lambda > 0:
                        # MSE loss weighted by pase_lambda that ties the
                        # distorted PASE output to the clean PASE
                        # soft-labels (loaded in m)
                        pase_reg_loss = hp.pase_lambda * F.mse_loss(m, m_clean)
                else:
                    with torch.no_grad():
                        m = hp.pase_cntnt(xm.unsqueeze(1))
            if hp.conversion:
                if hp.pase_id is not None:
                    if hp.pase_id_ft:
                        trg = hp.pase_id(neigh.unsqueeze(1))
                    else:
                        with torch.no_grad():
                            # skip grad tracking so the graph is not
                            # built (faster)
                            trg = hp.pase_id(neigh.unsqueeze(1))

            y_hat = model(x, m, trg)

            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)

            elif model.mode == 'MOL':
                y = y.float()

            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)

            running_nll_loss += loss.item()

            optimiser.zero_grad()
            if pase_reg_loss is not None:
                total_loss = loss + pase_reg_loss
                running_pase_reg_loss += pase_reg_loss.item()
                pase_reg_avg_loss = running_pase_reg_loss / i
            else:
                total_loss = loss
            total_loss.backward()
            optimiser.step()
            running_loss += total_loss.item()

            speed = i / (time.time() - start)
            nll_avg_loss = running_nll_loss / i
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if step % hp.voc_write_every == 0:
                hp.writer.add_scalar('train/nll', nll_avg_loss, step)
                if pase_reg_loss is not None:
                    hp.writer.add_scalar('train/pase_reg_loss',
                                         pase_reg_avg_loss, step)

            if step % hp.voc_checkpoint_every == 0:
                if eval_set is not None:
                    print('Validating')
                    # validate the model
                    val_loss = voc_eval_loop(model, loss_func, eval_set,
                                             device)
                    if val_loss <= min_val_loss:
                        patience = hp.patience
                        print('Val loss improved: {:.4f} -> '
                              '{:.4f}'.format(min_val_loss, val_loss))
                        min_val_loss = val_loss
                    else:
                        patience -= 1
                        print('Val loss did not improve. Patience '
                              '{}/{}'.format(patience, hp.patience))
                        if patience == 0:
                            print('Out of patience. Breaking the loop')
                            break
                    # set to train mode again
                    model.train()
                # generate some test samples
                gen_testset(model,
                            test_set,
                            hp.voc_gen_at_checkpoint,
                            hp.voc_gen_batched,
                            hp.voc_target,
                            hp.voc_overlap,
                            paths.voc_output,
                            hp=hp,
                            device=device)
                model.checkpoint(paths.voc_checkpoints)
                if hp.pase_cntnt is not None and hp.pase_cntnt_ft:
                    hp.pase_cntnt.train()
                    hp.pase_cntnt.save(paths.voc_checkpoints, step)
                if hp.conversion:
                    if hp.pase_id is not None and hp.pase_id_ft:
                        hp.pase_id.train()
                        hp.pase_id.save(paths.voc_checkpoints, step)

            if pase_reg_loss is None:
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | NLLoss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            else:
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Total Loss: {avg_loss:.4f} | NLLoss: {nll_avg_loss:.4f} | PASE reg loss: {pase_reg_avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        model.save(paths.voc_latest_weights)
        model.log(paths.voc_log, msg)
        print(' ')

        if patience == 0:
            # propagate early stopping to the epoch loop; the bare break
            # above only ends the current epoch
            break
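`voc_eval_loop` is referenced but not shown. A plausible sketch, assuming it averages the same criterion over the held-out set with gradients disabled (PASE/`trg` handling omitted):

import torch

def voc_eval_loop(model, loss_func, eval_set, device):
    # average the training criterion over the validation set
    model.eval()
    total, n = 0., 0
    with torch.no_grad():
        for m, xm, x, y, neigh in eval_set:
            x, m, y = x.to(device), m.to(device), y.to(device)
            y_hat = model(x, m, None)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            total += loss_func(y_hat, y.unsqueeze(-1)).item()
            n += 1
    return total / max(n, 1)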