def save_states(global_step,
                writer,
                y_hat,
                student_hat,
                y,
                input_lengths,
                checkpoint_dir=None):
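    """Write teacher, student, and target audio for one random batch example.

    Shapes are assumed to follow the training loop: y_hat is (B, C, T) (or
    (B, C, T, 1)), while y and student_hat collapse to (T,) after indexing.
    """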

    print("Saving intermediate states at step {}".format(global_step))
    # Pick a random example from the batch
    idx = np.random.randint(0, len(y_hat))
    length = input_lengths[idx].data.cpu().item()

    # Collapse a possible trailing singleton dim: (B, C, T, 1) -> (B, C, T)
    if y_hat.dim() == 4:
        y_hat = y_hat.squeeze(-1)

    if is_mulaw_quantize(hparams.input_type):
        # (B, T)
        y_hat = F.softmax(y_hat, dim=1).max(1)[1]

        # (T,)
        y_hat = y_hat[idx].data.cpu().long().numpy()
        y = y[idx].view(-1).data.cpu().long().numpy()
        # The student output is assumed to already be a continuous waveform,
        # so it is only indexed here (no inverse mu-law quantization applied).
        student_hat = student_hat[idx].view(-1).data.cpu().numpy()

        y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels)
        y = P.inv_mulaw_quantize(y, hparams.quantize_channels)
    else:
        # Sample a waveform from the predicted output distribution -> (B, T)
        if hparams.use_gaussian:
            y_hat = y_hat.transpose(1, 2)
            y_hat = sample_from_gaussian(y_hat,
                                         log_scale_min=hparams.log_scale_min)
        else:
            y_hat = sample_from_discretized_mix_logistic(
                y_hat, log_scale_min=hparams.log_scale_min)

        # (T,)
        y_hat = y_hat[idx].view(-1).data.cpu().numpy()
        y = y[idx].view(-1).data.cpu().numpy()
        student_hat = student_hat[idx].view(-1).data.cpu().numpy()

        if is_mulaw(hparams.input_type):
            y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels)
            y = P.inv_mulaw(y, hparams.quantize_channels)
            student_hat = P.inv_mulaw(student_hat, hparams.quantize_channels)

    # Mask: zero out samples beyond the true (unpadded) length
    y_hat[length:] = 0
    y[length:] = 0
    student_hat[length:] = 0

    # Save audio (librosa.output.write_wav requires librosa < 0.8)
    audio_dir = join(checkpoint_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)
    path = join(audio_dir, "step{:09d}_teacher.wav".format(global_step))
    librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
    path = join(audio_dir, "step{:09d}_student.wav".format(global_step))
    librosa.output.write_wav(path, student_hat, sr=hparams.sample_rate)
    path = join(audio_dir, "step{:09d}_target.wav".format(global_step))
    librosa.output.write_wav(path, y, sr=hparams.sample_rate)


def eval_model(global_step,
               writer,
               device,
               student,
               teacher,
               y,
               c,
               g,
               input_lengths,
               eval_dir,
               ema=None):
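    """Generate audio with the student from noise, re-score it with the
    teacher, and write student/teacher/target wavs plus a waveform plot
    for one random example in the batch.
    """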
    if ema is not None:
        print("Using averaged model for evaluation")
        student = clone_as_averaged_model(device, student, ema)
        student.make_generation_fast_()

    student.eval()
    teacher.eval()
    idx = np.random.randint(0, len(y))
    length = input_lengths[idx].data.cpu().item()

    # (T,)
    y_target = y[idx].view(-1).data.cpu().numpy()[:length]

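    # Local conditioning features: if the network upsamples them internally,
    # c is stored per-frame, so trim to length // hop_size frames; otherwise
    # it is already per-sample.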
    if c is not None:
        if hparams.upsample_conditional_features:
            c = c[idx, :, :length // audio.get_hop_size()].unsqueeze(0)
        else:
            c = c[idx, :, :length].unsqueeze(0)
        assert c.dim() == 3
        print("Shape of local conditioning features: {}".format(c.size()))
    if g is not None:
        # TODO: test
        g = g[idx]
        print("Shape of global conditioning features: {}".format(g.size()))

    # Noise input: the student maps z ~ N(0, 1) of shape (1, 1, T) to a waveform
    dist = torch.distributions.normal.Normal(loc=0., scale=1.)
    z = dist.sample((1, 1, length)).to(device)

    # Run the model
    with torch.no_grad():
        student_hat, _, _, _ = student(x=z,
                                       c=c,
                                       g=g,
                                       log_scale_min=hparams.log_scale_min,
                                       device=device)
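        # Re-score the generated waveform with the teacher, then draw a sample
        # from the teacher's predicted per-timestep Gaussian.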
        teacher_output = teacher(student_hat, c=c, g=g, softmax=False)
        teacher_output = teacher_output.transpose(1, 2)
        teacher_hat = sample_from_gaussian(teacher_output,
                                           log_scale_min=hparams.log_scale_min)

    # NOTE: no inverse mu-law transform is applied here; the outputs are
    # written as-is, which assumes a continuous (raw) input_type.

    teacher_hat = teacher_hat.view(-1).cpu().data.numpy()
    student_hat = student_hat.view(-1).cpu().data.numpy()

    # Save audio
    os.makedirs(eval_dir, exist_ok=True)
    path = join(eval_dir, "step{:09d}_student.wav".format(global_step))
    librosa.output.write_wav(path, student_hat, sr=hparams.sample_rate)
    path = join(eval_dir, "step{:09d}_teacher.wav".format(global_step))
    librosa.output.write_wav(path, teacher_hat, sr=hparams.sample_rate)
    path = join(eval_dir, "step{:09d}_target.wav".format(global_step))
    librosa.output.write_wav(path, y_target, sr=hparams.sample_rate)

    # save figure
    path = join(eval_dir, "step{:09d}_waveplots.png".format(global_step))
    save_waveplot(path, teacher_hat, y_target, student_hat)