Example #1
    def forward(self, y_hat, mu_q, scale_q, mask, sample_T=32):
        if hparams.output_type == 'Gaussian':
            # teacher p,student q
            mu_p, scale_p = y_hat[:, :1, :], torch.exp(y_hat[:, 1:, :])
            loss = torch.log(scale_p / scale_q) + \
                (scale_q ** 2 - scale_p ** 2 + (mu_q - mu_p) ** 2) / (2 * scale_p ** 2)
            # loss += torch.log(scale_q / scale_p) + (scale_p ** 2 - scale_q ** 2 + (mu_q - mu_p) ** 2) / (2 * scale_q ** 2)
            # loss /= 2
            loss += self.lambda_ * (torch.log(scale_p) - torch.log(scale_q)) ** 2
            kl_loss = torch.sum(loss[:, :, :-1] * mask.permute(0, 2, 1)) / mask.sum()
            return kl_loss
        elif hparams.output_type == "MOL":
            h_pt_ps = 0
            for i in range(sample_T):
                u = torch.zeros(mu_q.size()).uniform_(1e-5, 1 - 1e-5)
                if use_cuda:
                    u = u.cuda()
                z = torch.log(u) - torch.log(1 - u)
                student_predict = mu_q + z * scale_q
                assert student_predict.requires_grad is True

                student_predict = student_predict.permute(0, 2, 1)
                teacher_log_p = discretized_mix_logistic_loss(y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)
                h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum()

            # compute h_ps
            a = scale_q.permute(0, 2, 1)
            h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / (mask.sum())

            # compute kl loss
            cross_entropy = h_pt_ps / sample_T
            kl_loss = cross_entropy - h_ps
            return kl_loss
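The 'Gaussian' branch above uses the closed-form KL divergence between two Gaussians, KL(q || p) = log(s_p / s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 * s_p^2) - 1/2, with the -1/2 folded into the s_q^2 - s_p^2 term, while the 'MOL' branch estimates the cross-entropy by sampling. Below is a small standalone check of that closed form against a Monte Carlo estimate; the values and seed are illustrative, not taken from the repository.

# Sketch: sanity-check of the closed-form Gaussian KL used in the 'Gaussian'
# branch above. The scalar parameters below are illustrative only.
import torch

torch.manual_seed(0)
mu_p, scale_p = torch.tensor(0.3), torch.tensor(0.8)   # teacher p
mu_q, scale_q = torch.tensor(0.1), torch.tensor(0.5)   # student q

# closed form, written exactly as in the snippet (the -1/2 constant is folded
# into the (scale_q**2 - scale_p**2) term)
kl_closed = torch.log(scale_p / scale_q) + \
    (scale_q ** 2 - scale_p ** 2 + (mu_q - mu_p) ** 2) / (2 * scale_p ** 2)

# Monte Carlo estimate of KL(q || p) = E_q[log q(x) - log p(x)]
q = torch.distributions.Normal(mu_q, scale_q)
p = torch.distributions.Normal(mu_p, scale_p)
x = q.sample((200000,))
kl_mc = (q.log_prob(x) - p.log_prob(x)).mean()
print(float(kl_closed), float(kl_mc))  # the two values should be close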
Example #2
    def __init__(self, hparams):
        super(DiscretizedMixturelogisticLoss, self).__init__()
        self.quantize_channels = hparams.quantize_channels
        self.log_scale_min = hparams.log_scale_min
        self.discretized_mix_logistic_loss = discretized_mix_logistic_loss(
            num_classes=hparams.quantize_channels,
            log_scale_min=hparams.log_scale_min,
            reduce=False)
        self.reduce_sum_op = P.ReduceSum()
        self.reduce_mean_op = P.ReduceMean()
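Example #2 shows only the constructor; the reduction ops it creates are presumably applied in a construct() method elsewhere in the cell. A hypothetical sketch of such a method follows, assuming the wrapped discretized_mix_logistic_loss cell returns per-element losses shaped like the mask; this is not the repository's code.

    def construct(self, y_hat, y, mask):
        # Hypothetical pairing for the __init__ above: mask the per-element
        # losses and average over the valid positions only.
        losses = self.discretized_mix_logistic_loss(y_hat, y)
        return self.reduce_sum_op(losses * mask) / self.reduce_sum_op(mask)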
Example #3
def test_mixture():
    np.random.seed(1234)

    x, sr = librosa.load(pysptk.util.example_audio_file(), sr=None)
    assert sr == 16000

    T = len(x)
    x = x.reshape(1, T, 1)
    y = Variable(torch.from_numpy(x)).float()
    y_hat = Variable(torch.rand(1, 30, T)).float()

    print(y.shape, y_hat.shape)

    loss = discretized_mix_logistic_loss(y_hat, y)
    print(loss)

    loss = discretized_mix_logistic_loss(y_hat, y, reduce=False)
    print(loss.size(), y.size())
    assert loss.size() == y.size()

    y = sample_from_discretized_mix_logistic(y_hat)
    print(y.shape)
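test_mixture builds y_hat with 30 channels; under the usual 10-component discretized mixture of logistics that is 3 parameters per component (mixture logits, means, log scales). A small sketch of that split, assuming the (B, C, T) layout used above; the names and the exact ordering are illustrative rather than the library's internals.

# Sketch: how 30 output channels decompose for a 10-component logistic mixture.
import torch

B, T, num_mixtures = 1, 100, 10
y_hat = torch.rand(B, 3 * num_mixtures, T)           # (B, C, T) as in the test
params = y_hat.transpose(1, 2)                       # (B, T, C)
logit_probs = params[:, :, :num_mixtures]            # mixture weights (pre-softmax)
means = params[:, :, num_mixtures:2 * num_mixtures]  # per-component means
log_scales = params[:, :, 2 * num_mixtures:]         # per-component log scales
print(logit_probs.shape, means.shape, log_scales.shape)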
Example #4
    def forward(self, input, target, lengths=None, mask=None, max_len=None):
        if lengths is None and mask is None:
            raise RuntimeError("Should provide either lengths or mask")

        # (B, T, 1)
        if mask is None:
            mask = sequence_mask(lengths, max_len).unsqueeze(-1)

        # (B, T, 1)
        mask_ = mask.expand_as(target)

        losses = discretized_mix_logistic_loss(
            input, target, num_classes=hparams.quantize_channels,
            log_scale_min=hparams.log_scale_min, reduce=False)
        assert losses.size() == target.size()
        return ((losses * mask_).sum()) / mask_.sum()
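Example #4 relies on a sequence_mask(lengths, max_len) helper that returns a (B, T) mask of valid positions. Below is a minimal stand-in with the same call shape, as an assumed implementation; the repository's helper may differ in detail.

# Sketch: a minimal sequence_mask matching the call signature used in Example #4.
import torch

def sequence_mask(lengths, max_len=None):
    # lengths: (B,) tensor of valid lengths -> (B, T) float mask of 0/1
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)        # (T,)
    return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()  # (B, T)

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths, max_len=6))   # rows of ones up to each length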
Example #5
def __train_step(phase,
                 epoch,
                 global_step,
                 global_test_step,
                 teacher,
                 student,
                 optimizer,
                 writer,
                 x,
                 y,
                 c,
                 g,
                 input_lengths,
                 checkpoint_dir,
                 eval_dir=None,
                 do_eval=False,
                 ema=None):
    sanity_check(teacher, c, g)
    sanity_check(student, c, g)

    # x : (B, C, T)
    # y : (B, T, 1)
    # c : (B, C, T)
    # g : (B,)
    train = (phase == "train")
    clip_thresh = hparams.clip_thresh
    if train:
        teacher.eval()  # set teacher as eval mode
        student.train()
        step = global_step
    else:
        student.eval()
        step = global_test_step

    # ---------------------- the parallel wavenet use constant learning rate = 0.0002
    # Learning rate schedule
    current_lr = hparams.initial_learning_rate
    if train and hparams.lr_schedule is not None:
        lr_schedule_f = getattr(lrschedule, hparams.lr_schedule)
        current_lr = lr_schedule_f(hparams.initial_learning_rate, step,
                                   **hparams.lr_schedule_kwargs)
        if gpu_count > 1:
            for param_group in optimizer.module.param_groups:
                param_group['lr'] = current_lr
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr
    optimizer.zero_grad()
    # Prepare data
    x, y = Variable(x), Variable(y, requires_grad=False)
    c = Variable(c) if c is not None else None
    g = Variable(g) if g is not None else None
    input_lengths = Variable(input_lengths)
    if use_cuda:
        x, y = x.cuda(), y.cuda()
        input_lengths = input_lengths.cuda()
        c = c.cuda() if c is not None else None
        g = g.cuda() if g is not None else None

    # (B, T, 1)
    mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1)
    mask = mask[:, 1:, :]
    # apply the student model with stacked iaf layers and return mu,scale
    # u = Variable(torch.from_numpy(np.random.uniform(1e-5, 1 - 1e-5, x.size())).float().cuda(), requires_grad=False)
    # z = torch.log(u) - torch.log(1 - u)
    u = Variable(torch.zeros(*x.size()).uniform_(1e-5, 1 - 1e-5),
                 requires_grad=False).cuda()
    z = torch.log(u) - torch.log(1 - u)
    predict, mu, scale = student(z, c=c, g=g, softmax=False)
    m, s = mu, scale
    # mu, scale = to_numpy(mu), to_numpy(scale)
    # TODO sample times, change to 300 or 400
    sample_T, kl_loss_sum = 16, 0
    power_loss_sum = 0
    y_hat = teacher(predict, c=c,
                    g=g)  # y_hat: (B x C x T) teacher: 10-mixture-logistic
    h_pt_ps = 0
    # TODO add some constrain on scale ,we want it to be small?
    for i in range(sample_T):
        # https://en.wikipedia.org/wiki/Logistic_distribution
        u = Variable(torch.zeros(*x.size()).uniform_(1e-5, 1 - 1e-5),
                     requires_grad=False).cuda()
        z = torch.log(u) - torch.log(1 - u)
        student_predict = m + s * z  # predicted wave
        # student_predict.clamp(-0.99, 0.99)
        student_predict = student_predict.permute(0, 2, 1)
        _, teacher_log_p = discretized_mix_logistic_loss(
            y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)
        h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum()
        student_predict = student_predict.permute(0, 2, 1)
        power_loss_sum += get_power_loss_torch(student_predict,
                                               x,
                                               n_fft=512,
                                               hop_length=128)
        power_loss_sum += get_power_loss_torch(student_predict,
                                               x,
                                               n_fft=256,
                                               hop_length=64)
        power_loss_sum += get_power_loss_torch(student_predict,
                                               x,
                                               n_fft=2048,
                                               hop_length=512)
        power_loss_sum += get_power_loss_torch(student_predict,
                                               x,
                                               n_fft=1024,
                                               hop_length=256)
        power_loss_sum += get_power_loss_torch(student_predict,
                                               x,
                                               n_fft=128,
                                               hop_length=32)
    a = s.permute(0, 2, 1)
    h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / (mask.sum())
    cross_entropy = h_pt_ps / (sample_T)
    kl_loss = cross_entropy - 2 * h_ps
    # power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=64)
    # power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=128)
    # power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=256)
    # power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=512)
    power_loss = power_loss_sum / (5 * sample_T)
    loss = kl_loss + power_loss
    if step > 0 and step % 20 == 0:
        print('power_loss={}, mean_scale={}, mean_mu={},kl_loss={},loss={}'.
              format(to_numpy(power_loss), np.mean(to_numpy(s)),
                     np.mean(to_numpy(m)), to_numpy(kl_loss), to_numpy(loss)))
    if train and step > 0 and step % hparams.checkpoint_interval == 0:
        save_states(step,
                    writer,
                    y_hat=y_hat,
                    y=y,
                    y_student=predict,
                    input_lengths=input_lengths,
                    mu=m,
                    checkpoint_dir=checkpoint_dir)
        if step % (5 * hparams.checkpoint_interval) == 0:
            save_checkpoint(student, optimizer, step, checkpoint_dir, epoch)
    if do_eval and False:
        # NOTE: use train step (i.e., global_step) for filename
        # eval_model(global_step, writer, model, y, c, g, input_lengths, eval_dir, ema)
        eval_model(global_step, writer, student, y, c, g, input_lengths,
                   eval_dir, ema)

    # Update
    if train:
        loss.backward()
        if clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm(student.parameters(),
                                                      clip_thresh)
        if gpu_count > 1:
            optimizer.module.step()
        else:
            optimizer.step()
        # update moving average
        if ema is not None:
            for name, param in student.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

    # Logs
    writer.add_scalar("{} loss".format(phase), float(loss.data[0]), step)
    writer.add_scalar("{} _hps".format(phase), float(h_ps.data[0]), step)
    writer.add_scalar("{} h_pt_ps".format(phase), float(cross_entropy.data[0]),
                      step)
    writer.add_scalar("{} kl_loss".format(phase), float(kl_loss.data[0]), step)
    writer.add_scalar("{} power_loss".format(phase), float(power_loss.data[0]),
                      step)
    if train:
        if clip_thresh > 0:
            writer.add_scalar("gradient norm", grad_norm, step)
            # writer.add_scalar("gradient norm", grad_norm, step)
        # writer.add_scalar("learning rate", current_lr, step)

    return loss.data[0]
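Both training steps draw standard logistic noise through the inverse-CDF trick u ~ Uniform(0, 1), z = log(u) - log(1 - u), and reparameterize the student output as mu + scale * z so that gradients reach mu and scale. A standalone sketch of that trick follows; the shapes and seed are illustrative.

# Sketch: the inverse-CDF logistic sampling used in __train_step.
# If U ~ Uniform(0, 1), then log(U) - log(1 - U) follows a standard logistic
# distribution, and mu + scale * z is a reparameterized, differentiable sample.
import torch

torch.manual_seed(0)
u = torch.empty(4, 1, 4000).uniform_(1e-5, 1 - 1e-5)
z = torch.log(u) - torch.log(1 - u)              # standard logistic noise

mu = torch.zeros(4, 1, 4000, requires_grad=True)
scale = torch.ones(4, 1, 4000, requires_grad=True)
sample = mu + scale * z                          # gradients flow into mu, scale
sample.sum().backward()
# std of a standard logistic is pi / sqrt(3) ~= 1.81
print(float(z.std()), mu.grad.shape, scale.grad.shape)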
Example #6
def __train_step(phase,
                 epoch,
                 global_step,
                 global_test_step,
                 teacher,
                 student,
                 optimizer,
                 writer,
                 x,
                 y,
                 c,
                 g,
                 input_lengths,
                 checkpoint_dir,
                 eval_dir=None,
                 do_eval=False,
                 ema=None):
    sanity_check(teacher, c, g)
    sanity_check(student, c, g)

    # x : (B, C, T)
    # y : (B, T, 1)
    # c : (B, C, T)
    # g : (B,)
    train = (phase == "train")
    clip_thresh = hparams.clip_thresh
    if train:
        teacher.eval()  # set teacher as eval mode
        student.train()
        step = global_step
    else:
        student.eval()
        step = global_test_step

    # ---------------------- the parallel wavenet use constant learning rate = 0.0002
    # Learning rate schedule
    # current_lr = hparams.initial_learning_rate
    # if train and hparams.lr_schedule is not None:
    #     lr_schedule_f = getattr(lrschedule, hparams.lr_schedule)
    #     current_lr = lr_schedule_f(
    #         hparams.initial_learning_rate, step, **hparams.lr_schedule_kwargs)
    #     if gpu_count>1:
    #         for param_group in optimizer.module.param_groups:
    #             param_group['lr'] = current_lr
    #     else:
    #         for param_group in optimizer.param_groups:
    #             param_group['lr'] = current_lr
    optimizer.zero_grad()
    cross_entropy = nn.CrossEntropyLoss()
    # Prepare data
    x, y = Variable(x), Variable(y, requires_grad=False)
    c = Variable(c) if c is not None else None
    g = Variable(g) if g is not None else None
    input_lengths = Variable(input_lengths)
    if use_cuda:
        x, y = x.cuda(), y.cuda()
        input_lengths = input_lengths.cuda()
        c = c.cuda() if c is not None else None
        g = g.cuda() if g is not None else None

    # (B, T, 1)
    mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1)
    mask = mask[:, 1:, :]
    # mask.expand_as(y)
    # apply the student model with stacked iaf layers and return mu,scale
    z = Variable(
        torch.from_numpy(np.random.logistic(0, 1,
                                            size=x.size())).float()).cuda()
    mu, scale = student(z, c=c, g=g, softmax=False)
    m, s = mu, scale
    mu, scale = to_numpy(mu), to_numpy(scale)
    kl_loss, h_s = 0, 0
    _h_pt_ps = 0
    m = m.clamp(-0.999, 0.999)
    # start the accumulators from zero (torch.FloatTensor(1) is uninitialized memory)
    sample_T, kl_loss_sum = 5, 0
    power_loss_sum = 0
    for i in range(sample_T):
        z = np.random.logistic(0, 1, x.shape)
        student_predict = m + s * to_variable(z)  # predicted wave
        # sp = student_predict.clamp(-0.99, 0.99)
        student_predict = student_predict.clamp(-0.99, 0.99)
        y_hat = teacher(student_predict, c=c,
                        g=g)  # y_hat: (B x C x T) teacher: 10-mixture-logistic
        # sample from teacher distribution
        teacher_predict = sample_from_discretized_mix_logistic(y_hat)
        student_predict = student_predict.permute(0, 2, 1)
        _, teacher_log_p = discretized_mix_logistic_loss(
            y_hat[:, :, :-1], student_predict[:,
                                              1:, :], reduce=False)  # -log(Pt)
        # h_pt_ps = torch.sum(teacher_log_p * p_s * mask)  # / mask.sum()
        h_pt_ps = torch.sum(teacher_log_p * mask) / mask.sum()
        # h_pt_ps = F.cross_entropy(student_predict,teacher_predict)
        student_predict = student_predict.permute(0, 2, 1)
        power_loss_sum += get_power_loss_torch(student_predict, x)
        # _h_pt_ps += torch.sum(teacher_log_p)  # / mask.sum()
        a = s.permute(0, 2, 1)
        # h_ps = torch.sum(torch.log(p_s) * mask)  # / mask.sum()
        # cross_entorpy = F.cross_entropy(teacher_predict,student_predict)
        h_ps = torch.sum((teacher_log_p -
                          (torch.log(a[:, 1:, :]) + 2)) * mask) / mask.sum()
        kl_loss_sum += h_ps  #+ h_pt_ps
    kl_loss = kl_loss_sum / (hparams.batch_size * sample_T)
    power_loss = power_loss_sum / (hparams.batch_size * sample_T)
    loss = kl_loss  # + power_loss
    rs = kl_loss.cpu().data.numpy()
    if np.isinf(rs).any():
        print('inf detected')
    else:
        print('power_loss={}, mean_scale={}, mean_mu={},kl_loss={},loss={}'.
              format(to_numpy(power_loss), np.mean(scale), np.mean(mu),
                     to_numpy(kl_loss), to_numpy(loss)))
    if train and step > 0 and step % hparams.checkpoint_interval == 0:
        save_states(step, writer, y_hat, y, student_predict, input_lengths,
                    checkpoint_dir)
        if step % (5 * hparams.checkpoint_interval) == 0:
            save_checkpoint(student, optimizer, step, checkpoint_dir, epoch)
    if do_eval and False:
        # NOTE: use train step (i.e., global_step) for filename
        # eval_model(global_step, writer, model, y, c, g, input_lengths, eval_dir, ema)
        eval_model(global_step, writer, student, y, c, g, input_lengths,
                   eval_dir, ema)

    # Update
    if train:
        loss.backward()
        if clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm(student.parameters(),
                                                      clip_thresh)
        if gpu_count > 1:
            optimizer.module.step()
        else:
            optimizer.step()
        # update moving average
        if ema is not None:
            for name, param in student.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

    # Logs
    writer.add_scalar("{} loss".format(phase), float(loss.data[0]), step)
    writer.add_scalar("{} _hps".format(phase), float(h_ps.data[0]), step)
    writer.add_scalar("{} h_pt_ps".format(phase), float(h_pt_ps.data[0]), step)
    writer.add_scalar("{} kl_loss".format(phase), float(kl_loss.data[0]), step)
    if train:
        if clip_thresh > 0:
            writer.add_scalar("gradient norm", grad_norm, step)
            # writer.add_scalar("gradient norm", grad_norm, step)
        # writer.add_scalar("learning rate", current_lr, step)

    return loss.data[0]
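Both training steps finish by pushing student weights into an EMA object through ema.shadow and ema.update(name, param.data). Below is a minimal helper matching that usage under the usual decay rule shadow = decay * shadow + (1 - decay) * value; it is an assumption about the interface, not the repository's implementation.

# Sketch: a minimal exponential-moving-average helper compatible with the
# ema.shadow / ema.update(name, value) usage in __train_step.
import torch

class ExponentialMovingAverage(object):
    def __init__(self, decay=0.9999):
        self.decay = decay
        self.shadow = {}

    def register(self, name, value):
        # take a snapshot of the parameter as the starting shadow value
        self.shadow[name] = value.clone()

    def update(self, name, value):
        # shadow <- decay * shadow + (1 - decay) * value
        self.shadow[name] = (self.decay * self.shadow[name]
                             + (1.0 - self.decay) * value)

ema = ExponentialMovingAverage(decay=0.999)
w = torch.zeros(3)
ema.register("w", w)
ema.update("w", torch.ones(3))
print(ema.shadow["w"])   # a small step toward the new value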