Example #1
def save_states(global_step,
                writer,
                y_hat,
                student_hat,
                y,
                input_lengths,
                checkpoint_dir=None):

    print("Save intermediate states at step {}".format(global_step))
    idx = np.random.randint(0, len(y_hat))
    length = input_lengths[idx].data.cpu().item()

    # (B, C, T)
    if y_hat.dim() == 4:
        y_hat = y_hat.squeeze(-1)

    if is_mulaw_quantize(hparams.input_type):
        # (B, T)
        y_hat = F.softmax(y_hat, dim=1).max(1)[1]

        # (T,)
        y_hat = y_hat[idx].data.cpu().long().numpy()
        y = y[idx].view(-1).data.cpu().long().numpy()
        # student_hat is assumed to already be a scalar waveform; pull it to
        # NumPy here so the length masking and file writing below work.
        student_hat = student_hat[idx].view(-1).data.cpu().numpy()

        y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels)
        y = P.inv_mulaw_quantize(y, hparams.quantize_channels)
    else:
        # (B, T)
        if hparams.use_gaussian:
            y_hat = y_hat.transpose(1, 2)
            y_hat = sample_from_gaussian(y_hat,
                                         log_scale_min=hparams.log_scale_min)
        else:
            y_hat = sample_from_discretized_mix_logistic(
                y_hat, log_scale_min=hparams.log_scale_min)

        # (T,)
        y_hat = y_hat[idx].view(-1).data.cpu().numpy()
        y = y[idx].view(-1).data.cpu().numpy()
        student_hat = student_hat[idx].view(-1).data.cpu().numpy()

        if is_mulaw(hparams.input_type):
            y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels)
            y = P.inv_mulaw(y, hparams.quantize_channels)
            student_hat = P.inv_mulaw(student_hat, hparams.quantize_channels)

    # Mask by length
    y_hat[length:] = 0
    y[length:] = 0
    student_hat[length:] = 0

    # Save audio
    audio_dir = join(checkpoint_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)
    path = join(audio_dir, "step{:09d}_teacher.wav".format(global_step))
    librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
    path = join(audio_dir, "step{:09d}_student.wav".format(global_step))
    librosa.output.write_wav(path, student_hat, sr=hparams.sample_rate)
    path = join(audio_dir, "step{:09d}_target.wav".format(global_step))
    librosa.output.write_wav(path, y, sr=hparams.sample_rate)
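
For reference, the mu-law companding behind P.inv_mulaw / P.inv_mulaw_quantize can be sketched in plain NumPy. This follows the standard formulas with mu = 255 for 256 classes; it is not the actual nnmnkwii implementation, whose rounding and mu conventions have varied across wavenet_vocoder revisions (which is why Example #2 below passes quantize_channels - 1).

import numpy as np

def mulaw(x, mu=255):
    # Compress [-1, 1] -> [-1, 1]: sign(x) * log(1 + mu|x|) / log(1 + mu)
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)

def inv_mulaw(y, mu=255):
    # Exact inverse of the compression
    return np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu

def mulaw_quantize(x, mu=255):
    # Compressed signal -> integer classes in [0, mu]
    return np.round((mulaw(x, mu) + 1) / 2 * mu).astype(np.int64)

def inv_mulaw_quantize(q, mu=255):
    # Integer classes -> waveform in [-1, 1]
    return inv_mulaw(2.0 * q / mu - 1.0, mu)

x = np.linspace(-1, 1, 11)
assert np.allclose(inv_mulaw(mulaw(x)), x)  # roundtrip is exact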
Example #2
def save_states(global_step,
                writer,
                y_hat,
                y,
                input_lengths,
                checkpoint_dir=None):
    print("Save intermediate states at step {}".format(global_step))
    idx = np.random.randint(0, len(y_hat))
    length = input_lengths[idx].data.cpu().item()

    # (B, C, T)
    if y_hat.dim() == 4:
        y_hat = y_hat.squeeze(-1)

    if is_mulaw_quantize(wavenet_hparams.input_type):
        # (B, T)
        y_hat = F.softmax(y_hat, dim=1).max(1)[1]

        # (T,)
        y_hat = y_hat[idx].data.cpu().long().numpy()
        y = y[idx].view(-1).data.cpu().long().numpy()

        y_hat = P.inv_mulaw_quantize(y_hat,
                                     wavenet_hparams.quantize_channels - 1)
        y = P.inv_mulaw_quantize(y, wavenet_hparams.quantize_channels - 1)
    else:
        # (B, T)
        if wavenet_hparams.output_distribution == "Logistic":
            y_hat = sample_from_discretized_mix_logistic(
                y_hat, log_scale_min=wavenet_hparams.log_scale_min)
        elif wavenet_hparams.output_distribution == "Normal":
            y_hat = sample_from_mix_gaussian(
                y_hat, log_scale_min=wavenet_hparams.log_scale_min)
        else:
            assert False, "unknown output_distribution"

        # (T,)
        y_hat = y_hat[idx].view(-1).data.cpu().numpy()
        y = y[idx].view(-1).data.cpu().numpy()

        if is_mulaw(wavenet_hparams.input_type):
            y_hat = P.inv_mulaw(y_hat, wavenet_hparams.quantize_channels)
            y = P.inv_mulaw(y, wavenet_hparams.quantize_channels)

    # Mask by length
    y_hat[length:] = 0
    y[length:] = 0

    # Save audio
    audio_dir = join(checkpoint_dir, "intermediate", "audio")
    os.makedirs(audio_dir, exist_ok=True)
    path = join(audio_dir, "step{:09d}_predicted.wav".format(global_step))
    # librosa.output.write_wav(path, y_hat, sr=wavenet_hparams.sample_rate)
    sf.write(path, y_hat, samplerate=wavenet_hparams.sample_rate)
    path = join(audio_dir, "step{:09d}_target.wav".format(global_step))
    # librosa.output.write_wav(path, y, sr=wavenet_hparams.sample_rate)
    sf.write(path, y, samplerate=wavenet_hparams.sample_rate)
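
Example #2 dispatches on output_distribution; below is a minimal NumPy sketch of the "Normal" path in the spirit of sample_from_mix_gaussian, assuming the (B, C, T) layout with C = 3 * num_mixtures (weight logits, means, log-scales). It mirrors the idea, not the library's exact code.

import numpy as np

def sample_mog(params, log_scale_min=-7.0, rng=np.random):
    # params: (B, 3 * nmix, T) -> one sample per (batch, time) position
    B, C, T = params.shape
    nmix = C // 3
    logits = params[:, :nmix, :]                 # unnormalized mixture weights
    means = params[:, nmix:2 * nmix, :]
    log_scales = np.maximum(params[:, 2 * nmix:, :], log_scale_min)

    # Pick a component per (b, t) with the Gumbel-max trick
    gumbel = -np.log(-np.log(rng.uniform(1e-5, 1 - 1e-5, logits.shape)))
    k = np.argmax(logits + gumbel, axis=1)       # (B, T)

    b, t = np.meshgrid(np.arange(B), np.arange(T), indexing="ij")
    mu = means[b, k, t]
    sigma = np.exp(log_scales[b, k, t])
    return mu + sigma * rng.standard_normal(mu.shape)   # (B, T)

Example #3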
def test_mixture():
    np.random.seed(1234)

    x, sr = librosa.load(pysptk.util.example_audio_file(), sr=None)
    assert sr == 16000

    T = len(x)
    x = x.reshape(1, T, 1)
    y = Variable(torch.from_numpy(x)).float()
    y_hat = Variable(torch.rand(1, 30, T)).float()

    print(y.shape, y_hat.shape)

    loss = discretized_mix_logistic_loss(y_hat, y)
    print(loss)

    loss = discretized_mix_logistic_loss(y_hat, y, reduce=False)
    print(loss.size(), y.size())
    assert loss.size() == y.size()

    y = sample_from_discretized_mix_logistic(y_hat)
    print(y.shape)
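
In the test above, the 30 channels of y_hat encode 10 logistic components (10 weight logits, 10 means, 10 log-scales). The per-component draw inside sample_from_discretized_mix_logistic is the logistic inverse CDF; a sketch follows (the real sampler first picks the component from the mixture logits, as in the Gaussian sketch above):

import numpy as np

def sample_logistic(mean, log_scale, rng=np.random):
    # Inverse CDF of Logistic(mean, exp(log_scale))
    u = rng.uniform(1e-5, 1.0 - 1e-5, size=np.shape(mean))
    return mean + np.exp(log_scale) * (np.log(u) - np.log1p(-u))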
Example #4
def save_states(global_step,
                writer,
                y_hat,
                y,
                y_student,
                input_lengths,
                mu=None,
                checkpoint_dir=None):
    '''

    :param global_step:
    :param writer:
    :param y_hat: mixture parameters output by the teacher
    :param y: target
    :param y_student: student output
    :param input_lengths:
    :param mu: student mu
    :param checkpoint_dir:
    :return:
    '''
    print("Save intermediate states at step {}".format(global_step))
    idx = np.random.randint(0, len(y_hat))
    length = input_lengths[idx].data.cpu().item()
    if mu is not None:
        mu = mu[idx]
    # (B, C, T)
    if y_hat.dim() == 4:
        y_hat = y_hat.squeeze(-1)

    if is_mulaw_quantize(hparams.input_type):
        # (B, T)
        y_hat = F.softmax(y_hat, dim=1).max(1)[1]

        # (T,)
        y_hat = y_hat[idx].data.cpu().long().numpy()
        y = y[idx].view(-1).data.cpu().long().numpy()

        y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels)
        y = P.inv_mulaw_quantize(y, hparams.quantize_channels)
    else:
        # (B, T)
        y_hat = sample_from_discretized_mix_logistic(
            y_hat, log_scale_min=hparams.log_scale_min)
        # (T,)
        y_hat = y_hat[idx].view(-1).data.cpu().numpy()
        y = y[idx].view(-1).data.cpu().numpy()

        if is_mulaw(hparams.input_type):
            y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels)
            y = P.inv_mulaw(y, hparams.quantize_channels)

    # Mask by length
    y_hat[length:] = 0
    y[length:] = 0
    y_student = y_student[idx].data.cpu().numpy().reshape(-1)
    mu = to_numpy(mu) if mu is not None else None
    # Save audio every 1000 steps
    audio_dir = join(checkpoint_dir, "audio")
    if global_step % 1000 == 0:
        os.makedirs(audio_dir, exist_ok=True)
        path = join(audio_dir, "step{:09d}_teacher.wav".format(global_step))
        librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
        path = join(audio_dir, "step{:09d}_target.wav".format(global_step))
        librosa.output.write_wav(path, y, sr=hparams.sample_rate)
        path = join(audio_dir, "step{:09d}_student.wav".format(global_step))
        librosa.output.write_wav(path, y_student, sr=hparams.sample_rate)
    # TODO: save a waveform plot every 200 steps
    if global_step % 200 == 0:
        path = join(audio_dir, "wave_step{:09d}.png".format(global_step))
        save_waveplot(path,
                      y_student=y_student,
                      y_target=y,
                      y_teacher=y_hat,
                      student_mu=mu)
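
save_waveplot is defined elsewhere in this repo. A hypothetical matplotlib version with the same keyword interface might look like the sketch below; the argument names mirror the call site, everything else is an assumption.

import matplotlib
matplotlib.use("Agg")  # headless rendering for training jobs
import matplotlib.pyplot as plt

def save_waveplot(path, y_student, y_target, y_teacher, student_mu=None):
    # Stack one subplot per waveform so the signals are easy to compare
    series = [("student", y_student), ("target", y_target),
              ("teacher", y_teacher)]
    if student_mu is not None:
        series.append(("student mu", student_mu))
    fig, axes = plt.subplots(len(series), 1, figsize=(12, 2 * len(series)))
    for ax, (name, wav) in zip(axes, series):
        ax.plot(wav)
        ax.set_title(name)
    plt.tight_layout()
    plt.savefig(path)
    plt.close(fig)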
Example #5
    def incremental_forward(self,
                            initial_input=None,
                            c=None,
                            g=None,
                            T=100,
                            test_inputs=None,
                            tqdm=lambda x: x,
                            softmax=True,
                            quantize=True,
                            log_scale_min=-50.0):
        """Incremental forward step

        Due to linearized convolutions, inputs of shape (B x C x T) are reshaped
        to (B x T x C) internally and fed to the network for each time step.
        Input of each time step will be of shape (B x 1 x C).

        Args:
            initial_input (Tensor): Initial decoder input, (B x C x 1)
            c (Tensor): Local conditioning features, shape (B x C' x T)
            g (Tensor): Global conditioning features, shape (B x C'' or B x C''x 1)
            T (int): Number of time steps to generate.
            test_inputs (Tensor): Teacher forcing inputs (for debugging)
            tqdm (lambda): tqdm wrapper for the generation loop
            softmax (bool): Whether to apply softmax.
            quantize (bool): Whether to quantize the softmax output before
              feeding the network output back as input for the next time
              step. TODO: rename
            log_scale_min (float): Log scale minimum value.

        Returns:
            Tensor: Generated one-hot encoded samples (B x C x T),
              or a scalar waveform (B x 1 x T).
        """
        self.clear_buffer()
        B = 1

        # Note: shape should be **(B x T x C)**, not (B x C x T) as in
        # batch forward, due to the linearized convolutions
        if test_inputs is not None:
            if self.scalar_input:
                if test_inputs.size(1) == 1:
                    test_inputs = test_inputs.transpose(1, 2).contiguous()
            else:
                if test_inputs.size(1) == self.out_channels:
                    test_inputs = test_inputs.transpose(1, 2).contiguous()

            B = test_inputs.size(0)
            if T is None:
                T = test_inputs.size(1)
            else:
                T = max(T, test_inputs.size(1))
        # cast to int in case of numpy.int64...
        T = int(T)

        # Global conditioning
        if g is not None:
            if self.embed_speakers is not None:
                g = self.embed_speakers(g.view(B, -1))
                # (B, gin_channels, 1)
                g = g.transpose(1, 2)
                assert g.dim() == 3
        g_btc = _expand_global_features(B, T, g, bct=False)

        # Local conditioning
        if c is not None:
            B = c.shape[0]
            if self.upsample_net is not None:
                c = self.upsample_net(c)
                assert c.size(-1) == T
            if c.size(-1) == T:
                c = c.transpose(1, 2).contiguous()

        outputs = []
        if initial_input is None:
            if self.scalar_input:
                initial_input = torch.zeros(B, 1, 1)
            else:
                initial_input = torch.zeros(B, 1, self.out_channels)
                initial_input[:, :, 127] = 1  # mu-law zero point, assuming 256 classes. TODO: is this ok?
            # https://github.com/pytorch/pytorch/issues/584#issuecomment-275169567
            if next(self.parameters()).is_cuda:
                initial_input = initial_input.cuda()
        else:
            if initial_input.size(1) == self.out_channels:
                initial_input = initial_input.transpose(1, 2).contiguous()

        current_input = initial_input

        for t in tqdm(range(T)):
            if test_inputs is not None and t < test_inputs.size(1):
                current_input = test_inputs[:, t, :].unsqueeze(1)
            else:
                if t > 0:
                    current_input = outputs[-1]

            # Conditioning features for single time step
            ct = None if c is None else c[:, t, :].unsqueeze(1)
            gt = None if g is None else g_btc[:, t, :].unsqueeze(1)

            x = current_input
            x = self.first_conv.incremental_forward(x)
            skips = 0
            for f in self.conv_layers:
                x, h = f.incremental_forward(x, ct, gt)
                skips += h
            skips *= math.sqrt(1.0 / len(self.conv_layers))
            x = skips
            for f in self.last_conv_layers:
                try:
                    x = f.incremental_forward(x)
                except AttributeError:
                    x = f(x)

            # Generate next input by sampling
            if self.scalar_input:
                if self.output_distribution == "Logistic":
                    x = sample_from_discretized_mix_logistic(
                        x.view(B, -1, 1), log_scale_min=log_scale_min)
                elif self.output_distribution == "Normal":
                    x = sample_from_mix_gaussian(x.view(B, -1, 1),
                                                 log_scale_min=log_scale_min)
                else:
                    assert False, "unknown output_distribution"
            else:
                x = x.view(B, -1)
                if softmax:
                    x = F.softmax(x, dim=1)
                if quantize:
                    dist = torch.distributions.OneHotCategorical(x)
                    x = dist.sample()
            outputs += [x.data]
        # T x B x C
        outputs = torch.stack(outputs)
        # B x C x T
        outputs = outputs.transpose(0, 1).transpose(1, 2).contiguous()

        self.clear_buffer()
        return outputs
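
A minimal usage sketch for incremental_forward, assuming a trained model instance `model` and a local conditioning tensor `c` of shape (B, cin_channels, frames); the hop size of the upsampling network is an assumption here.

import torch
from tqdm import tqdm

hop_size = 256               # assumed upsampling factor of upsample_net
T = c.size(-1) * hop_size    # generated length must match the upsampled c

model.eval()
with torch.no_grad():
    y_hat = model.incremental_forward(
        c=c, g=None, T=T, tqdm=tqdm,
        softmax=True, quantize=True, log_scale_min=-50.0)
# y_hat: (B, C, T) one-hot samples, or (B, 1, T) for scalar-input models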
def __train_step(phase,
                 epoch,
                 global_step,
                 global_test_step,
                 teacher,
                 student,
                 optimizer,
                 writer,
                 x,
                 y,
                 c,
                 g,
                 input_lengths,
                 checkpoint_dir,
                 eval_dir=None,
                 do_eval=False,
                 ema=None):
    sanity_check(teacher, c, g)
    sanity_check(student, c, g)

    # x : (B, C, T)
    # y : (B, T, 1)
    # c : (B, C, T)
    # g : (B,)
    train = (phase == "train")
    clip_thresh = hparams.clip_thresh
    if train:
        teacher.eval()  # set teacher as eval mode
        student.train()
        step = global_step
    else:
        student.eval()
        step = global_test_step

    # NOTE: Parallel WaveNet uses a constant learning rate of 2e-4, so the
    # schedule below is disabled.
    # Learning rate schedule
    # current_lr = hparams.initial_learning_rate
    # if train and hparams.lr_schedule is not None:
    #     lr_schedule_f = getattr(lrschedule, hparams.lr_schedule)
    #     current_lr = lr_schedule_f(
    #         hparams.initial_learning_rate, step, **hparams.lr_schedule_kwargs)
    #     if gpu_count>1:
    #         for param_group in optimizer.module.param_groups:
    #             param_group['lr'] = current_lr
    #     else:
    #         for param_group in optimizer.param_groups:
    #             param_group['lr'] = current_lr
    optimizer.zero_grad()
    cross_entropy = nn.CrossEntropyLoss()  # currently unused; see the commented lines below
    # Prepare data
    x, y = Variable(x), Variable(y, requires_grad=False)
    c = Variable(c) if c is not None else None
    g = Variable(g) if g is not None else None
    input_lengths = Variable(input_lengths)
    if use_cuda:
        x, y = x.cuda(), y.cuda()
        input_lengths = input_lengths.cuda()
        c = c.cuda() if c is not None else None
        g = g.cuda() if g is not None else None

    # (B, T, 1)
    mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1)
    mask = mask[:, 1:, :]
    # mask.expand_as(y)
    # apply the student model with stacked iaf layers and return mu,scale
    z = Variable(
        torch.from_numpy(np.random.logistic(0, 1,
                                            size=x.size())).float()).cuda()
    mu, scale = student(z, c=c, g=g, softmax=False)
    m, s = mu, scale
    mu, scale = to_numpy(mu), to_numpy(scale)
    kl_loss, h_s = 0, 0
    _h_pt_ps = 0
    m = m.clamp(-0.999, 0.999)
    sample_T = 5
    # The accumulator must start at zero; torch.FloatTensor(1) is uninitialized
    kl_loss_sum = Variable(torch.zeros(1), requires_grad=True).cuda()
    power_loss_sum = 0
    for i in range(sample_T):
        z = np.random.logistic(0, 1, x.shape)
        student_predict = m + s * to_variable(z)  # predicted wave
        # sp = student_predict.clamp(-0.99, 0.99)
        student_predict = student_predict.clamp(-0.99, 0.99)
        y_hat = teacher(student_predict, c=c,
                        g=g)  # y_hat: (B x C x T) teacher: 10-mixture-logistic
        # sample from teacher distribution
        teacher_predict = sample_from_discretized_mix_logistic(y_hat)
        student_predict = student_predict.permute(0, 2, 1)
        _, teacher_log_p = discretized_mix_logistic_loss(
            y_hat[:, :, :-1], student_predict[:, 1:, :],
            reduce=False)  # -log(Pt)
        # h_pt_ps = torch.sum(teacher_log_p * p_s * mask)  # / mask.sum()
        h_pt_ps = torch.sum(teacher_log_p * mask) / mask.sum()
        # h_pt_ps = F.cross_entropy(student_predict,teacher_predict)
        student_predict = student_predict.permute(0, 2, 1)
        power_loss_sum += get_power_loss_torch(student_predict, x)
        # _h_pt_ps += torch.sum(teacher_log_p)  # / mask.sum()
        a = s.permute(0, 2, 1)
        # h_ps = torch.sum(torch.log(p_s) * mask)  # / mask.sum()
        # cross_entropy = F.cross_entropy(teacher_predict, student_predict)
        # Despite the name, h_ps is the full KL estimate H(Ps, Pt) - H(Ps):
        # log(s) + 2 is the closed-form entropy of a logistic with scale s.
        h_ps = torch.sum((teacher_log_p -
                          (torch.log(a[:, 1:, :]) + 2)) * mask) / mask.sum()
        kl_loss_sum += h_ps  # + h_pt_ps
    kl_loss = kl_loss_sum / (hparams.batch_size * sample_T)
    power_loss = power_loss_sum / (hparams.batch_size * sample_T)
    loss = kl_loss  # + power_loss
    rs = kl_loss.cpu().data.numpy()
    if np.isinf(rs):
        print('inf detected')
    else:
        print('power_loss={}, mean_scale={}, mean_mu={},kl_loss={},loss={}'.
              format(to_numpy(power_loss), np.mean(scale), np.mean(mu),
                     to_numpy(kl_loss), to_numpy(loss)))
    if train and step > 0 and step % hparams.checkpoint_interval == 0:
        save_states(step, writer, y_hat, y, student_predict, input_lengths,
                    checkpoint_dir)
        if step % (5 * hparams.checkpoint_interval) == 0:
            save_checkpoint(student, optimizer, step, checkpoint_dir, epoch)
    if do_eval and False:  # evaluation currently disabled
        # NOTE: use train step (i.e., global_step) for filename
        # eval_model(global_step, writer, model, y, c, g, input_lengths, eval_dir, ema)
        eval_model(global_step, writer, student, y, c, g, input_lengths,
                   eval_dir, ema)

    # Update
    if train:
        loss.backward()
        if clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm(student.parameters(),
                                                      clip_thresh)
        if gpu_count > 1:
            optimizer.module.step()
        else:
            optimizer.step()
        # update moving average
        if ema is not None:
            for name, param in student.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

    # Logs
    writer.add_scalar("{} loss".format(phase), float(loss.data[0]), step)
    writer.add_scalar("{} _hps".format(phase), float(h_ps.data[0]), step)
    writer.add_scalar("{} h_pt_ps".format(phase), float(h_pt_ps.data[0]), step)
    writer.add_scalar("{} kl_loss".format(phase), float(kl_loss.data[0]), step)
    if train:
        if clip_thresh > 0:
            writer.add_scalar("gradient norm", grad_norm, step)
            # writer.add_scalar("gradient norm", grad_norm, step)
        # writer.add_scalar("learning rate", current_lr, step)

    return loss.data[0]
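
The loop above is a Monte Carlo estimate of the probability density distillation objective from Parallel WaveNet: KL(Ps || Pt) = H(Ps, Pt) - H(Ps), where H(Ps) has the closed form E[log s] + 2 for a logistic with scale s. A compact restatement under the same definitions (a sketch, not a drop-in replacement):

import torch

def distill_kl(teacher_neg_log_p, log_s, mask):
    # teacher_neg_log_p: -log Pt(x) per sample, as returned by
    # discretized_mix_logistic_loss(..., reduce=False) above
    # log_s: student log-scales aligned with teacher_neg_log_p
    h_pt_ps = (teacher_neg_log_p * mask).sum() / mask.sum()  # H(Ps, Pt)
    h_ps = ((log_s + 2.0) * mask).sum() / mask.sum()         # H(Ps)
    return h_pt_ps - h_ps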