Example #1
import time

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.normal import Normal

# Module-level objects such as device, train_loader, test_loader, stft,
# criterion_l1, criterion_l2, criterion_t, criterion_frame, args, state,
# and global_step are assumed to be defined in the surrounding script;
# the four examples below share this context.

def train(epoch, model, optimizer, scheduler, ema):
    global global_step
    epoch_loss = 0.
    running_loss = [0., 0., 0., 0., 0.]
    model.train()
    start_time = time.time()
    display_step = 100
    for batch_idx, (x, _, c, _) in enumerate(train_loader):
        global_step += 1

        x, c = x.to(device), c.to(device)

        optimizer.zero_grad()
        # Forward pass: reconstruction, a sample decoded from the prior, and
        # the unreduced reconstruction / KL losses
        x_rec, x_prior, loss_rec, loss_kl = model(x, c)

        # Linear- and log-scale STFT magnitudes of the reconstruction, the
        # ground truth, and the prior sample (the first waveform sample is
        # dropped)
        stft_rec, stft_rec_log = stft(x_rec[:, 0, 1:])
        stft_truth, stft_truth_log = stft(x[:, 0, 1:])
        stft_prior, stft_prior_log = stft(x_prior[:, 0, 1:])

        loss_frame_rec = criterion_l2(stft_rec, stft_truth) + criterion_l1(
            stft_rec_log, stft_truth_log)
        loss_frame_prior = criterion_l2(stft_prior, stft_truth) + criterion_l1(
            stft_prior_log, stft_truth_log)

        # KL annealing coefficient: a sigmoid ramp in the global step that
        # reaches 0.5 at step 5e5
        alpha = 1 / (1 + np.exp(-5e-5 * (global_step - 5e+5)))
        loss_rec, loss_kl = loss_rec.mean(), loss_kl.mean()
        loss_tot = loss_rec + loss_kl * alpha + loss_frame_rec + loss_frame_prior
        loss_tot.backward()

        nn.utils.clip_grad_norm_(model.parameters(), 10.)
        optimizer.step()
        scheduler.step()  # step the LR scheduler after the optimizer update
        if ema is not None:
            for name, param in model.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

        running_loss[0] += loss_tot.item() / display_step
        running_loss[1] += loss_rec.item() / display_step
        running_loss[2] += loss_kl.item() / display_step
        running_loss[3] += loss_frame_rec.item() / display_step
        running_loss[4] += loss_frame_prior.item() / display_step
        epoch_loss += loss_tot.item()
        if (batch_idx + 1) % display_step == 0:
            end_time = time.time()
            print(
                'Global Step : {}, [{}, {}] [Total Loss, Rec Loss, KL Loss, STFT Recon, STFT Prior] : {}'
                .format(global_step, epoch, batch_idx + 1,
                        np.array(running_loss)))
            print('{} Step Time : {}'.format(display_step,
                                             end_time - start_time))
            start_time = time.time()
            running_loss = [0., 0., 0., 0., 0.]
        del loss_tot, loss_frame_rec, loss_frame_prior, loss_kl, loss_rec, x, c, x_rec, x_prior
        del stft_rec, stft_rec_log, stft_truth, stft_truth_log, stft_prior, stft_prior_log
    print('{} Epoch Training Loss : {:.4f}'.format(
        epoch, epoch_loss / (len(train_loader))))
    return epoch_loss / len(train_loader)
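
Examples #1 and #3 call a module-level stft helper that returns a pair of linear- and log-scale magnitude spectrograms, while Examples #2 and #4 request only the linear magnitude via scale='linear'. The helper itself is not shown in the source; a minimal sketch consistent with those call sites might look like the following, where the FFT size, hop size, and clamping epsilon are illustrative assumptions rather than the repository's actual values:

import torch

def stft(x, fft_size=1024, hop_size=256, scale='both'):
    # Magnitude STFT over a batch of 1-D signals. Returns (linear, log)
    # magnitudes by default, or only the linear magnitude for scale='linear'.
    window = torch.hann_window(fft_size, device=x.device)
    spec = torch.stft(x, n_fft=fft_size, hop_length=hop_size,
                      window=window, return_complex=True)
    mag = spec.abs().clamp(min=1e-7)
    if scale == 'linear':
        return mag
    return mag, mag.log()
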
Example #2
@torch.no_grad()  # evaluation does not need gradients
def evaluate(model_t, model_s, ema=None):
    # Fall back to the raw student when no EMA tracker is supplied, so that
    # model_s_ema is always defined
    model_s_ema = clone_as_averaged_model(model_s, ema) if ema is not None else model_s
    model_t.eval()
    model_s_ema.eval()
    running_loss = [0., 0., 0., 0.]
    epoch_loss = 0.

    display_step = 100
    for batch_idx, (x, y, c, _) in enumerate(test_loader):
        x, y, c = x.to(device), y.to(device), c.to(device)

        # White-noise input z ~ N(0, I) for the student, plus conditioning
        # features upsampled by the teacher's upsampling network
        q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size()))
        z = q_0.sample()
        c_up = model_t.upsample(c)

        # Draw a sample from the averaged student
        x_student, mu_s, logs_s = model_s_ema(z, c_up)

        # Score the student sample with the teacher
        mu_logs_t = model_t(x_student, c)

        if args.KL_type == 'pq':
            loss_t, loss_KL, loss_reg = criterion_t(mu_logs_t[:, 0:1, :-1],
                                                    mu_logs_t[:, 1:, :-1],
                                                    mu_s, logs_s)
        elif args.KL_type == 'qp':
            loss_t, loss_KL, loss_reg = criterion_t(mu_s, logs_s,
                                                    mu_logs_t[:, 0:1, :-1],
                                                    mu_logs_t[:, 1:, :-1])
        else:
            raise ValueError('unknown KL_type: {}'.format(args.KL_type))

        stft_student = stft(x_student[:, 0, 1:], scale='linear')
        stft_truth = stft(x[:, 0, 1:], scale='linear')

        loss_frame = criterion_frame(stft_student, stft_truth.detach())

        loss_tot = loss_t + loss_frame

        running_loss[0] += loss_tot.item() / display_step
        running_loss[1] += loss_KL.item() / display_step
        running_loss[2] += loss_reg.item() / display_step
        running_loss[3] += loss_frame.item() / display_step
        epoch_loss += loss_tot.item()

        if (batch_idx + 1) % display_step == 0:
            print('{} [Total, KL, Reg, Frame Loss] : {}'.format(
                batch_idx + 1, np.array(running_loss)))
            running_loss = [0., 0., 0., 0.]
        del loss_tot, loss_frame, loss_KL, loss_reg, loss_t, x, y, c, c_up, stft_student, stft_truth, q_0, z
        del x_student, mu_s, logs_s, mu_logs_t
    epoch_loss /= len(test_loader)
    print('Evaluation Loss : {:.4f}'.format(epoch_loss))
    del model_s_ema
    return epoch_loss
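
Both evaluation loops rely on clone_as_averaged_model to build a copy of the network whose weights are replaced by the EMA shadow values. That helper is not part of the snippet; a hypothetical implementation consistent with the ema.shadow dictionary used above could be:

import copy

def clone_as_averaged_model(model, ema):
    # Deep-copy the model, then overwrite every tracked parameter with its
    # exponential-moving-average shadow value.
    averaged = copy.deepcopy(model)
    for name, param in averaged.named_parameters():
        if name in ema.shadow:
            param.data.copy_(ema.shadow[name])
    return averaged
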
Example #3
@torch.no_grad()  # evaluation does not need gradients
def evaluate(model, ema=None):
    # Fall back to the raw model when no EMA tracker is supplied, so that
    # model_ema is always defined
    model_ema = clone_as_averaged_model(model, ema) if ema is not None else model
    model_ema.eval()
    running_loss = [0., 0., 0., 0., 0.]
    epoch_loss = 0.

    display_step = 100
    for batch_idx, (x, _, c, _) in enumerate(test_loader):
        x, c = x.to(device), c.to(device)

        # Evaluate with the EMA-averaged weights
        x_rec, x_prior, loss_rec, loss_kl = model_ema(x, c)

        stft_rec, stft_rec_log = stft(x_rec[:, 0, 1:])
        stft_truth, stft_truth_log = stft(x[:, 0, 1:])
        stft_prior, stft_prior_log = stft(x_prior[:, 0, 1:])

        loss_frame_rec = criterion_l2(stft_rec, stft_truth) + criterion_l1(
            stft_rec_log, stft_truth_log)
        loss_frame_prior = criterion_l2(stft_prior, stft_truth) + criterion_l1(
            stft_prior_log, stft_truth_log)

        # KL annealing coefficient: the same sigmoid ramp as in training,
        # here centered at step 1e6
        alpha = 1 / (1 + np.exp(-5e-5 * (global_step - 1e+6)))
        loss_rec, loss_kl = loss_rec.mean(), loss_kl.mean()
        loss_tot = loss_rec + loss_kl * alpha + loss_frame_rec + loss_frame_prior

        running_loss[0] += loss_tot.item() / display_step
        running_loss[1] += loss_rec.item() / display_step
        running_loss[2] += loss_kl.item() / display_step
        running_loss[3] += loss_frame_rec.item() / display_step
        running_loss[4] += loss_frame_prior.item() / display_step
        epoch_loss += loss_tot.item()

        if (batch_idx + 1) % display_step == 0:
            print(
                'Global Step : {}, [{}] [Total Loss, Rec Loss, KL Loss, STFT Recon, STFT Prior] : {}'
                .format(global_step, batch_idx + 1,
                        np.array(running_loss)))
            running_loss = [0., 0., 0., 0., 0.]
        del loss_tot, loss_frame_rec, loss_frame_prior, loss_kl, loss_rec, x, c, x_rec, x_prior
        del stft_rec, stft_rec_log, stft_truth, stft_truth_log, stft_prior, stft_prior_log
    epoch_loss /= len(test_loader)
    print('Evaluation Loss : {:.4f}'.format(epoch_loss))
    del model_ema
    return epoch_loss
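
The ema object itself only needs a shadow dictionary and an update method to satisfy the call sites in these examples. A minimal tracker along those lines is sketched below; the decay constant is an assumption, and each parameter must be registered once before training starts:

class ExponentialMovingAverage:
    # Minimal EMA tracker matching the ema.shadow / ema.update(...) usage
    # above; 0.9999 is a placeholder decay, not the repository's value.
    def __init__(self, decay=0.9999):
        self.decay = decay
        self.shadow = {}

    def register(self, name, value):
        self.shadow[name] = value.clone()

    def update(self, name, value):
        self.shadow[name] = (self.decay * self.shadow[name]
                             + (1.0 - self.decay) * value)

Registration would happen right after model construction, e.g. for name, param in model.named_parameters(): ema.register(name, param.data).
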
Example #4
def train(epoch, model_t, model_s, optimizer, ema):
    global global_step
    epoch_loss = 0.0
    running_loss = [0.0, 0.0, 0.0, 0.0]
    model_t.eval()
    model_s.train()
    start_time = time.time()
    display_step = 100
    for batch_idx, (x, y, c, _) in enumerate(train_loader):
        global_step += 1
        # Halve the learning rate at 200k, 400k, and 600k global steps
        if global_step in (200000, 400000, 600000):
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
                state['learning_rate'] = param_group['lr']

        x, y, c = x.to(device), y.to(device), c.to(device)

        # White-noise input to the student: z ~ N(0, I)
        q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size()))
        z = q_0.sample()

        optimizer.zero_grad()
        c_up = model_t.upsample(c)
        # Draw a student sample: x_student ~ N(mu_s, logs_s.exp())
        x_student, mu_s, logs_s = model_s(z, c_up)

        # Score the student sample with the teacher (its weights are not updated)
        mu_logs_t = model_t(x_student, c)

        if args.KL_type == 'pq':
            loss_t, loss_KL, loss_reg = criterion_t(mu_logs_t[:, 0:1, :-1],
                                                    mu_logs_t[:, 1:, :-1],
                                                    mu_s, logs_s)
        elif args.KL_type == 'qp':
            loss_t, loss_KL, loss_reg = criterion_t(mu_s, logs_s,
                                                    mu_logs_t[:, 0:1, :-1],
                                                    mu_logs_t[:, 1:, :-1])
        else:
            raise ValueError('unknown KL_type: {}'.format(args.KL_type))

        stft_student = stft(x_student[:, 0, 1:], scale='linear')
        stft_truth = stft(x[:, 0, 1:], scale='linear')
        loss_frame = criterion_frame(stft_student, stft_truth)
        loss_tot = loss_t + loss_frame
        loss_tot.backward()

        nn.utils.clip_grad_norm_(model_s.parameters(), 10.)
        optimizer.step()
        if ema is not None:
            for name, param in model_s.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

        running_loss[0] += loss_tot.item() / display_step
        running_loss[1] += loss_KL.item() / display_step
        running_loss[2] += loss_reg.item() / display_step
        running_loss[3] += loss_frame.item() / display_step
        epoch_loss += loss_tot.item()
        if (batch_idx + 1) % display_step == 0:
            end_time = time.time()
            print(
                'Global Step : {}, [{}, {}] [Total Loss, KL Loss, Reg Loss, Frame Loss] : {}'
                .format(global_step, epoch, batch_idx + 1,
                        np.array(running_loss)))
            print('{} Step Time : {}'.format(display_step,
                                             end_time - start_time))
            start_time = time.time()
            running_loss = [0.0, 0.0, 0.0, 0.0]
        del loss_tot, loss_frame, loss_KL, loss_reg, loss_t, x, y, c, c_up, stft_student, stft_truth, q_0, z
        del x_student, mu_s, logs_s, mu_logs_t
    print('{} Epoch Training Loss : {:.4f}'.format(
        epoch, epoch_loss / (len(train_loader))))
    return epoch_loss / len(train_loader)
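
In Examples #2 and #4, criterion_t takes a (mu, log-scale) pair from the student and from the teacher and returns a total, KL, and regularization loss; the KL_type flag only decides which distribution plays q and which plays p. The per-sample quantity such a criterion reduces is the closed-form KL divergence between two diagonal Gaussians. A sketch of that term, assuming logs stores log standard deviations and leaving out the repository's regularization and reduction logic:

import torch

def gaussian_kl(mu_q, logs_q, mu_p, logs_p):
    # KL(q || p) for diagonal Gaussians:
    # log(s_p / s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 s_p^2) - 1/2
    var_ratio = torch.exp(2.0 * (logs_q - logs_p))
    mean_term = (mu_q - mu_p).pow(2) / torch.exp(2.0 * logs_p)
    return logs_p - logs_q + 0.5 * (var_ratio + mean_term) - 0.5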