Пример #1
0
def save_states(global_step,
                writer,
                y_hat,
                y,
                input_lengths,
                checkpoint_dir=None):
    """Write one randomly chosen predicted/target waveform pair as WAV files.

    One item of the batch is picked at random. The network output and the
    target are converted back to waveforms according to
    ``hparams.input_type``, every sample past the item's length is zeroed,
    and both signals are written to
    ``<checkpoint_dir>/intermediate/audio/step<global_step>_{predicted,target}.wav``.

    Args:
        global_step: Training step; used in the log line and the file names.
        writer: Not used by this function; kept so existing callers still work.
        y_hat: Network output for the batch. A 4-D tensor has its last axis
            squeezed away before further processing.
        y: Target batch.
        input_lengths: Per-item lengths; the entry of the chosen item is used
            as the masking length.
        checkpoint_dir: Root directory for the output. It is passed straight
            to ``join``, so the default of ``None`` is not a usable value.

    Raises:
        ValueError: If the input type is not quantized and
            ``hparams.output_distribution`` is neither ``"Logistic"`` nor
            ``"Normal"``.
    """
    print("Save intermediate states at step {}".format(global_step))
    idx = np.random.randint(0, len(y_hat))
    length = input_lengths[idx].data.cpu().item()

    # (B, C, T)
    if y_hat.dim() == 4:
        y_hat = y_hat.squeeze(-1)

    input_type = hparams.input_type
    mulaw_quantized = is_mulaw_quantize(input_type)
    if mulaw_quantized or is_linear_quantize(input_type):
        # (B, T): most probable class at every time step
        y_hat = F.softmax(y_hat, dim=1).max(1)[1]

        # (T,)
        y_hat = y_hat[idx].data.cpu().long().numpy()
        y = y[idx].view(-1).data.cpu().long().numpy()

        # Both quantization schemes are undone with the same call shape.
        if mulaw_quantized:
            decode = P.inv_mulaw_quantize
        else:
            decode = inv_linear_quantize
        y_hat = decode(y_hat, hparams.quantize_channels - 1)
        y = decode(y, hparams.quantize_channels - 1)
    else:
        # (B, T): draw a sample from the predicted mixture distribution
        distribution = hparams.output_distribution
        if distribution == "Logistic":
            y_hat = sample_from_discretized_mix_logistic(
                y_hat, log_scale_min=hparams.log_scale_min)
        elif distribution == "Normal":
            y_hat = sample_from_mix_gaussian(
                y_hat, log_scale_min=hparams.log_scale_min)
        else:
            # An explicit raise (rather than ``assert False``) keeps the guard
            # active under ``python -O``; otherwise the raw network output
            # would silently be written as audio.
            raise ValueError(
                "Unsupported output_distribution: {!r}".format(distribution))

        # (T,)
        y_hat = y_hat[idx].view(-1).data.cpu().numpy()
        y = y[idx].view(-1).data.cpu().numpy()

        if is_mulaw(input_type):
            y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels)
            y = P.inv_mulaw(y, hparams.quantize_channels)

    # Mask by length
    y_hat[length:] = 0
    y[length:] = 0

    # Save audio
    audio_dir = join(checkpoint_dir, "intermediate", "audio")
    os.makedirs(audio_dir, exist_ok=True)
    path = join(audio_dir, "step{:09d}_predicted.wav".format(global_step))
    librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
    path = join(audio_dir, "step{:09d}_target.wav".format(global_step))
    librosa.output.write_wav(path, y, sr=hparams.sample_rate)
Пример #2
0
def batch_wavegen(model, c=None, g=None, fast=True, tqdm=tqdm, length=None, writing_dir=None):
    """Generate a batch of waveforms and optionally record a loss value.

    The model is run with ``incremental_forward`` and its output is converted
    back to waveforms according to ``hparams.input_type``. The optional
    post-processing and gain scaling from ``hparams`` are then applied.

    Args:
        model: Model providing ``eval``, ``make_generation_fast_``,
            ``incremental_forward`` and ``binary_softmax_loss``.
        c: Local conditioning features, or ``None``. Its first dimension is
            taken as the batch size; without it the batch size is 1.
        g: Global conditioning features, or ``None``.
        fast: If true, ``model.make_generation_fast_()`` is called first.
        tqdm: Progress-bar callable forwarded to ``incremental_forward``.
        length: Number of samples to generate. When ``None`` and
            ``hparams.upsample_conditional_features`` is set, it is derived
            from the number of conditioning frames.
        writing_dir: Directory that receives ``info.json``. When ``None``
            (the default) nothing is written and the loss is not computed.

    Returns:
        A numpy array with one generated waveform per row.
    """
    from train import sanity_check
    sanity_check(model, c, g)

    # Without local conditioning a single waveform is generated.
    B = 1 if c is None else c.shape[0]

    model.eval()
    if fast:
        model.make_generation_fast_()

    # Transform data to GPU
    g = None if g is None else g.to(device)
    c = None if c is None else c.to(device)

    if hparams.upsample_conditional_features and length is None:
        length = (c.shape[-1] - hparams.cin_pad * 2) * audio.get_hop_size()

    with torch.no_grad():
        y_hat = model.incremental_forward(
            c=c, g=g, T=length, tqdm=tqdm, softmax=True, quantize=True,
            log_scale_min=hparams.log_scale_min)

        # The loss is only needed for info.json. Previously the default
        # writing_dir=None crashed in join() after the whole generation had
        # already run.
        if writing_dir is not None:
            y_hat_sample = y_hat.max(1)[1].view(B, -1).float()
            cross_entropy = model.binary_softmax_loss(
                y_hat_sample.unsqueeze(1), c)

            # Write the output
            # NOTE(review): the key "0.244" is hard-coded; presumably it names
            # the noise level of the evaluated model. Confirm with the caller.
            with open(join(writing_dir, "info.json"), "w") as f:
                data = {"0.244": float(cross_entropy.detach().cpu().numpy())}
                json.dump(data, f, indent=4)

    if is_mulaw_quantize(hparams.input_type):
        # needs to be float since mulaw_inv returns in range of [-1, 1]
        y_hat = y_hat.max(1)[1].view(B, -1).float().cpu().data.numpy()
        for i in range(B):
            y_hat[i] = P.inv_mulaw_quantize(y_hat[i], hparams.quantize_channels - 1)
    elif is_linear_quantize(hparams.input_type):
        y_hat = y_hat.max(1)[1].view(B, -1).float().cpu().data.numpy()
        for i in range(B):
            y_hat[i] = inv_linear_quantize(y_hat[i], hparams.quantize_channels - 1)
    elif is_mulaw(hparams.input_type):
        y_hat = y_hat.view(B, -1).cpu().data.numpy()
        for i in range(B):
            y_hat[i] = P.inv_mulaw(y_hat[i], hparams.quantize_channels - 1)
    else:
        y_hat = y_hat.view(B, -1).cpu().data.numpy()

    # Optional post-processing, looked up by name on the ``audio`` module.
    if hparams.postprocess is not None and hparams.postprocess not in ["", "none"]:
        for i in range(B):
            y_hat[i] = getattr(audio, hparams.postprocess)(y_hat[i])

    if hparams.global_gain_scale > 0:
        for i in range(B):
            y_hat[i] /= hparams.global_gain_scale

    return y_hat
Пример #3
0
def batch_wavegen(model, c=None, g=None, fast=True, tqdm=tqdm, length=None):
    """Generate a batch of waveforms with the model's incremental forward pass.

    Args:
        model: Model providing ``eval``, ``make_generation_fast_`` and
            ``incremental_forward``.
        c: Local conditioning features, or ``None``. Its first dimension is
            taken as the batch size; without it the batch size is 1.
        g: Global conditioning features, or ``None``.
        fast: If true, ``model.make_generation_fast_()`` is called first.
        tqdm: Progress-bar callable forwarded to ``incremental_forward``.
        length: Number of samples to generate. When ``None`` and
            ``hparams.upsample_conditional_features`` is set, it is derived
            from the number of conditioning frames.

    Returns:
        A numpy array with one generated waveform per row.
    """
    from train import sanity_check

    sanity_check(model, c, g)

    # Without local conditioning a single waveform is generated.
    B = 1 if c is None else c.shape[0]

    model.eval()
    if fast:
        model.make_generation_fast_()

    # Move the conditioning features to the compute device
    if g is not None:
        g = g.to(device)
    if c is not None:
        c = c.to(device)

    if hparams.upsample_conditional_features and length is None:
        n_frames = c.shape[-1] - hparams.cin_pad * 2
        length = n_frames * audio.get_hop_size()

    with torch.no_grad():
        generated = model.incremental_forward(
            c=c, g=g, T=length, tqdm=tqdm, softmax=True, quantize=True,
            log_scale_min=hparams.log_scale_min)

    # Pick the per-row inverse transform and whether the model output is a
    # distribution over classes (argmax needed) or already a waveform.
    input_type = hparams.input_type
    if is_mulaw_quantize(input_type):
        row_decoder, categorical = P.inv_mulaw_quantize, True
    elif is_linear_quantize(input_type):
        row_decoder, categorical = inv_linear_quantize, True
    elif is_mulaw(input_type):
        row_decoder, categorical = P.inv_mulaw, False
    else:
        row_decoder, categorical = None, False

    if categorical:
        # kept as float because the inverse transform returns values in [-1, 1]
        waves = generated.max(1)[1].view(B, -1).float().cpu().data.numpy()
    else:
        waves = generated.view(B, -1).cpu().data.numpy()

    if row_decoder is not None:
        for row in range(B):
            waves[row] = row_decoder(waves[row],
                                     hparams.quantize_channels - 1)

    # Optional post-processing, looked up by name on the ``audio`` module.
    post_name = hparams.postprocess
    if post_name is not None and post_name not in ["", "none"]:
        for row in range(B):
            waves[row] = getattr(audio, post_name)(waves[row])

    if hparams.global_gain_scale > 0:
        for row in range(B):
            waves[row] /= hparams.global_gain_scale

    return waves
Пример #4
0
def main(args):
    """Separate a mixture of two recordings using two noise-conditioned models.

    The two input files are loaded, summed, peak-normalised, amplified and
    linearly quantized to form the observed mixture. Two estimates ``x0`` and
    ``x1`` (each padded with ``receptive_field`` samples on both sides) are
    then refined over a decreasing list of noise levels. For every noise
    level the matching checkpoints are loaded and a number of windowed update
    steps is run. Each step combines the gradient of the model log-probability,
    Gaussian Langevin noise and a reconstruction term that pulls
    ``x0 + x1`` towards the mixture.

    Files written to the output directory: ``mixed.wav`` and, after every
    noise level, ``out0_<sigma>.wav`` / ``out1_<sigma>.wav``.

    Args:
        args: Mapping with the keys ``"<output-dir>"``, ``"<input-file1>"``,
            ``"<input-file2>"``, ``"<checkpoint0>"`` and ``"<checkpoint1>"``.
    """
    global SAMPLE_SIZE

    model0 = ModelWrapper()
    model1 = ModelWrapper()

    receptive_field = model0.receptive_field

    writing_dir = args["<output-dir>"]
    os.makedirs(writing_dir, exist_ok=True)
    print("writing dir: {}".format(writing_dir))

    source1 = librosa.core.load(args["<input-file1>"], sr=22050, mono=True)[0]
    source2 = librosa.core.load(args["<input-file2>"], sr=22050, mono=True)[0]
    mixed = source1 + source2

    # Increase the volume of the mixture to avoid artifacts from linear encoding
    mixed /= abs(mixed).max()
    mixed *= 1.4
    mixed = linear_quantize(mixed + 1.0, hparams.quantize_channels - 1)
    # SAMPLE_SIZE == -1 means "use the whole mixture".
    if SAMPLE_SIZE == -1:
        SAMPLE_SIZE = int(mixed.shape[0])

    mixed = mixed[:SAMPLE_SIZE]

    mixed = torch.FloatTensor(mixed).reshape(1, -1).to(device)

    # Write inputs
    mixed_out = inv_linear_quantize(mixed[0].detach().cpu().numpy(),
                                    hparams.quantize_channels - 1) - 1.0
    mixed_out = np.clip(mixed_out, -1, 1)
    sf.write(join(writing_dir, "mixed.wav"), mixed_out, hparams.sample_rate)

    # Initial estimates: x0 starts as (mixture - 127) and x1 as the constant
    # 127; the uniform draws only allocate tensors of the right shape and are
    # overwritten immediately. Both are padded with 127 on either side.
    x0 = torch.FloatTensor(np.random.uniform(0, 512,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x0[:] = mixed - 127.0
    x0 = F.pad(x0, (receptive_field, receptive_field), "constant", 127)
    x0.requires_grad = True

    x1 = torch.FloatTensor(np.random.uniform(0, 512,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x1[:] = 127.
    x1 = F.pad(x1, (receptive_field, receptive_field), "constant", 127)
    x1.requires_grad = True

    # Noise levels, visited from largest to smallest.
    sigmas = [
        175.9, 110., 68.7, 54.3, 42.9, 34.0, 26.8, 21.2, 16.8, 13.3, 10.5,
        8.29, 6.55, 5.18, 4.1, 3.24, 2.56, 1.6, 1.0, 0.625, 0.39, 0.244, 0.15,
        0.1
    ]

    np.random.seed(999)

    for sigma in sigmas:
        # We make sure each sample is updated a certain number of times
        n_steps = int((SAMPLE_SIZE / (SGLD_WINDOW * BATCH_SIZE)) * N_STEPS)
        print("Number of SGLD steps {}".format(n_steps))
        # Bump down a model
        checkpoint_path0 = join(args["<checkpoint0>"], CHECKPOINTS[sigma],
                                "checkpoint_latest_ema.pth")
        model0.load_checkpoint(checkpoint_path0)
        checkpoint_path1 = join(args["<checkpoint1>"], CHECKPOINTS[sigma],
                                "checkpoint_latest_ema.pth")
        model1.load_checkpoint(checkpoint_path1)

        parmodel0 = torch.nn.DataParallel(model0)
        parmodel0.to(device)
        parmodel1 = torch.nn.DataParallel(model1)
        parmodel1.to(device)

        # Step size and reconstruction weight for this noise level.
        eta = .05 * (sigma**2)
        gamma = 15 * (1.0 / sigma)**2

        t0 = time.time()
        for i in range(n_steps):
            # need to get a good sampling of the beginning/end (boundary effects)
            # to understand this: think about how often we would update x[receptive_field] (first point)
            # if we only sampled U(receptive_field,x0.shape-receptive_field-SGLD_WINDOW)
            j = np.random.randint(receptive_field - SGLD_WINDOW,
                                  x0.shape[1] - receptive_field, BATCH_SIZE)
            j = np.maximum(j, receptive_field)
            j = np.minimum(j, x0.shape[1] - (SGLD_WINDOW + receptive_field))

            # Seed the padding with noised up silence. x0 and x1 are leaf
            # tensors that require grad, so these in-place writes have to
            # happen with gradient recording disabled (the same way the
            # update step below does); otherwise autograd rejects them.
            with torch.no_grad():
                x0[0, :receptive_field] = torch.FloatTensor(
                    np.random.normal(127, sigma,
                                     mixed[0, :receptive_field].shape)).to(device)
                x0[0, -receptive_field:] = torch.FloatTensor(
                    np.random.normal(127, sigma,
                                     mixed[0, -receptive_field:].shape)).to(device)
                x1[0, :receptive_field] = torch.FloatTensor(
                    np.random.normal(127, sigma,
                                     mixed[0, :receptive_field].shape)).to(device)
                x1[0, -receptive_field:] = torch.FloatTensor(
                    np.random.normal(127, sigma,
                                     mixed[0, -receptive_field:].shape)).to(device)

            # Cut one patch per window start. The model patches carry
            # receptive_field samples of context on both sides; ``mixed`` is
            # not padded, hence the index shift by receptive_field.
            patches0 = []
            patches1 = []
            mixpatch = []
            for k in range(BATCH_SIZE):
                patches0.append(x0[:, j[k] - receptive_field:j[k] +
                                   SGLD_WINDOW + receptive_field])
                patches1.append(x1[:, j[k] - receptive_field:j[k] +
                                   SGLD_WINDOW + receptive_field])
                mixpatch.append(mixed[:, j[k] - receptive_field:j[k] -
                                      receptive_field + SGLD_WINDOW])

            patches0 = torch.stack(patches0, axis=0)
            patches1 = torch.stack(patches1, axis=0)
            mixpatch = torch.stack(mixpatch, axis=0)

            # Forward pass
            log_prob, prediction0 = parmodel0(patches0, sigma=sigma)
            log_prob0 = torch.sum(log_prob)
            grad0 = torch.autograd.grad(log_prob0, x0)[0]

            log_prob, prediction1 = parmodel1(patches1, sigma=sigma)
            log_prob1 = torch.sum(log_prob)
            grad1 = torch.autograd.grad(log_prob1, x1)[0]

            # Gradient part of the update, restricted to each window
            x0_update, x1_update = [], []
            for k in range(BATCH_SIZE):
                x0_update.append(eta * grad0[:, j[k]:j[k] + SGLD_WINDOW])
                x1_update.append(eta * grad1[:, j[k]:j[k] + SGLD_WINDOW])

            # Langevin step
            for k in range(BATCH_SIZE):
                epsilon0 = np.sqrt(2 * eta) * torch.normal(
                    0, 1, size=(1, SGLD_WINDOW), device=device)
                x0_update[k] += epsilon0

                epsilon1 = np.sqrt(2 * eta) * torch.normal(
                    0, 1, size=(1, SGLD_WINDOW), device=device)
                x1_update[k] += epsilon1

            # Reconstruction step: both estimates are pushed by the same
            # residual (x0 + x1 - mixture) on the window itself.
            for k in range(BATCH_SIZE):
                x0_update[k] -= eta * gamma * (
                    patches0[k][:, receptive_field:receptive_field +
                                SGLD_WINDOW] +
                    patches1[k][:, receptive_field:receptive_field +
                                SGLD_WINDOW] - mixpatch[k])
                x1_update[k] -= eta * gamma * (
                    patches0[k][:, receptive_field:receptive_field +
                                SGLD_WINDOW] +
                    patches1[k][:, receptive_field:receptive_field +
                                SGLD_WINDOW] - mixpatch[k])

            # Apply the updates in place without recording them in autograd
            with torch.no_grad():
                for k in range(BATCH_SIZE):
                    x0[:, j[k]:j[k] + SGLD_WINDOW] += x0_update[k]
                    x1[:, j[k]:j[k] + SGLD_WINDOW] += x1_update[k]

            if (not i % 40) or (i == (n_steps - 1)):  # debugging
                print("--------------")
                print('sigma = {}'.format(sigma))
                print('eta = {}'.format(eta))
                print("i {}".format(i))
                print("Max sample {}".format(abs(x0).max()))
                print('Mean sample logpx: {}'.format(
                    log_prob0 / (BATCH_SIZE * SGLD_WINDOW)))
                print('Mean sample logpy: {}'.format(
                    log_prob1 / (BATCH_SIZE * SGLD_WINDOW)))
                print("Max gradient update: {}".format(eta * abs(grad0).max()))
                print("Reconstruction: {}".format(
                    abs(x0[:, receptive_field:-receptive_field] +
                        x1[:, receptive_field:-receptive_field] -
                        mixed).mean()))
                print('Elapsed time = {}'.format(time.time() - t0))
                t0 = time.time()

        # Snapshot of both estimates (padding included) after this noise level
        out0 = inv_linear_quantize(x0[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out0 = np.clip(out0, -1, 1)
        sf.write(join(writing_dir, "out0_{}.wav".format(sigma)), out0,
                 hparams.sample_rate)

        out1 = inv_linear_quantize(x1[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out1 = np.clip(out1, -1, 1)
        sf.write(join(writing_dir, "out1_{}.wav".format(sigma)), out1,
                 hparams.sample_rate)
Пример #5
0
import librosa
import numpy as np
from glob import glob
import os

from wavenet_vocoder.util import inv_linear_quantize

# Directory that is searched recursively for dumped waveform files.
in_dir = "/projects/grail/vjayaram/wavenet_vocoder/egs/linear_quantize/drums/dump/train_no_dev/"
extension = "*-wave.npy"
pattern = os.path.join(in_dir, "**/") + extension
src_files = sorted(glob(pattern, recursive=True))

# Absolute sample values of every file read so far.
all_data = []
for file_name in src_files:
    data = inv_linear_quantize(np.load(file_name), 255)
    all_data.append(np.abs(data))
    # Running 99.5th percentile over all files loaded up to this point.
    seen_so_far = np.concatenate(all_data).flatten()
    print(np.percentile(seen_so_far, 99.5))
Пример #6
0
def eval_model(global_step,
               writer,
               device,
               model,
               y,
               c,
               g,
               input_lengths,
               eval_dir,
               ema=None):
    """Generate one utterance with the model and save it next to its target.

    A random batch item is selected and its conditioning features are cut to
    the item's length. The model is then run with ``incremental_forward``
    starting from a "silence" input. The predicted and target waveforms are
    written as WAV files, together with a wave plot, into ``eval_dir``.

    Args:
        global_step: Training step; used in the output file names.
        writer: Not used by this function.
        device: Device the initial input is moved to (also used when cloning
            the averaged model).
        model: Model to evaluate.
        y: Target batch.
        c: Local conditioning features, or ``None``.
        g: Global conditioning features, or ``None``.
        input_lengths: Per-item lengths; the entry of the chosen item is the
            number of samples to generate.
        eval_dir: Output directory; created if missing.
        ema: If given, evaluation runs on a clone of ``model`` built by
            ``clone_as_averaged_model``.
    """
    if ema is not None:
        print("Using averaged model for evaluation")
        model = clone_as_averaged_model(device, model, ema)
        model.make_generation_fast_()

    model.eval()
    pick = np.random.randint(0, len(y))
    n_samples = input_lengths[pick].data.cpu().item()

    # (T,)
    y_target = y[pick].view(-1).data.cpu().numpy()[:n_samples]

    if c is not None:
        if hparams.upsample_conditional_features:
            # frame-rate features: keep the frames covering n_samples plus
            # the padding on both sides
            n_frames = n_samples // audio.get_hop_size() + hparams.cin_pad * 2
            c = c[pick, :, :n_frames].unsqueeze(0)
        else:
            c = c[pick, :, :n_samples].unsqueeze(0)
        assert c.dim() == 3
        print("Shape of local conditioning features: {}".format(c.size()))
    if g is not None:
        # TODO: test
        g = g[pick]
        print("Shape of global conditioning features: {}".format(g.size()))

    # Classify the input representation once; at most one flag is true.
    input_type = hparams.input_type
    mulaw_q = is_mulaw_quantize(input_type)
    linear_q = (not mulaw_q) and is_linear_quantize(input_type)
    plain_mulaw = (not mulaw_q) and (not linear_q) and is_mulaw(input_type)

    # Value that represents silence in the chosen representation
    if mulaw_q:
        initial_value = P.mulaw_quantize(0, hparams.quantize_channels - 1)
    elif linear_q:
        initial_value = linear_quantize(0, hparams.quantize_channels - 1)
    elif plain_mulaw:
        initial_value = P.mulaw(0.0, hparams.quantize_channels)
    else:
        initial_value = 0.0

    # (C,): one-hot input for quantized types unless scalar input is forced
    if (mulaw_q or linear_q) and not hparams.manual_scalar_input:
        one_hot = to_categorical(
            initial_value,
            num_classes=hparams.quantize_channels).astype(np.float32)
        initial_input = torch.from_numpy(one_hot).view(
            1, 1, hparams.quantize_channels)
    else:
        initial_input = torch.zeros(1, 1, 1).fill_(initial_value)
    initial_input = initial_input.to(device)

    # Run the model in fast eval mode
    with torch.no_grad():
        y_hat = model.incremental_forward(initial_input,
                                          c=c,
                                          g=g,
                                          T=n_samples,
                                          softmax=True,
                                          quantize=True,
                                          tqdm=tqdm,
                                          log_scale_min=hparams.log_scale_min)

    # Convert prediction and target back to waveforms
    if mulaw_q or linear_q:
        decode = P.inv_mulaw_quantize if mulaw_q else inv_linear_quantize
        classes = y_hat.max(1)[1].view(-1).long().cpu().data.numpy()
        y_hat = decode(classes, hparams.quantize_channels - 1)
        y_target = decode(y_target, hparams.quantize_channels - 1)
    elif plain_mulaw:
        y_hat = P.inv_mulaw(
            y_hat.view(-1).cpu().data.numpy(), hparams.quantize_channels)
        y_target = P.inv_mulaw(y_target, hparams.quantize_channels)
    else:
        y_hat = y_hat.view(-1).cpu().data.numpy()

    # Save audio
    os.makedirs(eval_dir, exist_ok=True)
    predicted_path = join(eval_dir,
                          "step{:09d}_predicted.wav".format(global_step))
    librosa.output.write_wav(predicted_path, y_hat, sr=hparams.sample_rate)
    target_path = join(eval_dir, "step{:09d}_target.wav".format(global_step))
    librosa.output.write_wav(target_path, y_target, sr=hparams.sample_rate)

    # save figure
    plot_path = join(eval_dir, "step{:09d}_waveplots.png".format(global_step))
    save_waveplot(plot_path, y_hat, y_target)
Пример #7
0
def main(args):
    """Run annealed Langevin sampling on two signals with two separate models.

    Two ground-truth excerpts are loaded from fixed ``.npy`` paths and written
    out, along with their sum, as reference WAV files. Both estimates start
    from the ground truth plus Gaussian noise (std 256) behind a left padding
    of ``receptive_field`` samples. The noise level ``sigma`` is annealed
    exponentially from 256 to 0.1 over ``n_steps`` iterations. Whenever
    ``sigma`` drops below the level of the currently loaded model, the next
    pair of checkpoints from ``sigmas`` is loaded.

    The reconstruction term is commented out, so the two estimates are
    updated independently of each other and of the mixture. The results are
    written to ``out0.wav`` and ``out1.wav`` in the working directory.

    Args:
        args: Mapping with the keys ``"<checkpoint0>"`` and
            ``"<checkpoint1>"`` (checkpoint root directories).
    """
    model0 = build_model().to(device)
    model0.eval()

    model1 = build_model().to(device)
    model1.eval()
    receptive_field = model0.receptive_field

    x0_original = np.load("supra_piano/dump/dev/zf882fv0052-wave.npy")
    x0_original = x0_original[200000:200000 + SAMPLE_SIZE]

    # x0_original = np.load("vctk/dump/dev/p374_422-wave.npy")
    # x0_original = x0_original[20000:20000 + SAMPLE_SIZE]

    x1_original = np.load("vctk/dump/dev/p341_048-wave.npy")
    x1_original = x1_original[32000:32000 + SAMPLE_SIZE]

    mixed = torch.FloatTensor(x0_original + x1_original).reshape(1,
                                                                 -1).to(device)

    # Write inputs
    mixed_out = inv_linear_quantize(mixed[0].detach().cpu().numpy(),
                                    hparams.quantize_channels - 1) - 1.0
    mixed_out = np.clip(mixed_out, -1, 1)
    sf.write("mixed.wav", mixed_out, hparams.sample_rate)

    x0_original_out = inv_linear_quantize(x0_original,
                                          hparams.quantize_channels - 1)
    sf.write("x0_original.wav", x0_original_out, hparams.sample_rate)

    x1_original_out = inv_linear_quantize(x1_original,
                                          hparams.quantize_channels - 1)
    sf.write("x1_original.wav", x1_original_out, hparams.sample_rate)

    # Initialize with noise
    x0 = torch.FloatTensor(np.random.uniform(-256, 512,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x0 = F.pad(x0, (receptive_field, 0), "constant", 127)
    x0.requires_grad = True

    x1 = torch.FloatTensor(np.random.uniform(-256, 512,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x1 = F.pad(x1, (receptive_field, 0), "constant", 127)
    x1.requires_grad = True

    # Initialize with noised GT. x0 and x1 are leaf tensors that already
    # require grad, so the in-place writes must run with gradient recording
    # disabled; otherwise autograd rejects them.
    with torch.no_grad():
        x0[0, receptive_field:] = torch.FloatTensor(
            x0_original +
            np.random.normal(0, 256., x0_original.shape)).to(device)
        x1[0, receptive_field:] = torch.FloatTensor(
            x1_original +
            np.random.normal(0, 256., x1_original.shape)).to(device)

    # Noise levels for which checkpoints are looked up, largest first.
    sigmas = [
        175.9, 110., 68.7, 42.9, 26.8, 16.8, 10.5, 4.1, 2.56, 1.6, 1.0, 0.0
    ]
    n_steps = 10000
    start_sigma = 256.
    end_sigma = 0.1

    # Exponential annealing
    ratio = (end_sigma / start_sigma)**(1.0 / n_steps)
    sigma = start_sigma

    # Dummy start values
    curr_model_idx = -1
    curr_model_sigma = 1000000.

    for i in range(n_steps):
        # Bump down a model
        if sigma < curr_model_sigma:
            curr_model_idx += 1
            curr_model_sigma = sigmas[curr_model_idx]

            checkpoint_path0 = join(args["<checkpoint0>"],
                                    checkpoints[curr_model_sigma],
                                    "checkpoint_latest.pth")
            checkpoint_path1 = join(args["<checkpoint1>"],
                                    checkpoints[curr_model_sigma],
                                    "checkpoint_latest.pth")
            print("Load checkpoint0 from {}".format(checkpoint_path0))
            checkpoint0 = torch.load(checkpoint_path0)
            checkpoint1 = torch.load(checkpoint_path1)
            model0.load_state_dict(checkpoint0["state_dict"])
            model1.load_state_dict(checkpoint1["state_dict"])

        eta = .05 * (sigma**2)
        # Only used by the (currently disabled) reconstruction step below.
        gamma = 15 * (1.0 / sigma)**2

        # Uncomment to see GT log likelihoods per sigma
        # x0[0, receptive_field:] = torch.FloatTensor(x0_original + np.random.normal(0, sigma, x0_original.shape)).to(device)
        # x1[0, receptive_field:] = torch.FloatTensor(x1_original + np.random.normal(0, sigma, x1_original.shape)).to(device)

        # Forward pass
        model0.zero_grad()
        log_prob, prediction0 = model0.smoothed_loss(x0, sigma=sigma)
        log_prob0 = torch.sum(log_prob[:, (receptive_field - 1):])
        # log_prob0 = torch.sum(log_prob)
        grad0 = torch.autograd.grad(log_prob0, x0)[0]

        # Only the non-padded part of the signal is updated.
        x0_update = eta * grad0[:, receptive_field:]
        # x0_update = eta * grad0

        model1.zero_grad()
        log_prob, prediction1 = model1.smoothed_loss(x1, sigma=sigma)
        log_prob1 = torch.sum(log_prob[:, (receptive_field - 1):])
        # log_prob1 = torch.sum(log_prob)
        grad1 = torch.autograd.grad(log_prob1, x1)[0]

        x1_update = eta * grad1[:, receptive_field:]
        # x1_update = eta * grad1

        # Langevin step
        epsilon0 = np.sqrt(2 * eta) * torch.normal(
            0, 1, size=(1, SAMPLE_SIZE), device=device)
        x0_update += epsilon0

        epsilon1 = np.sqrt(2 * eta) * torch.normal(
            0, 1, size=(1, SAMPLE_SIZE), device=device)
        x1_update += epsilon1

        # Reconstruction step
        # x0_update -= eta * gamma * (x0[:, receptive_field:] + x1[:, receptive_field:] - mixed)
        # x1_update -= eta * gamma * (x0[:, receptive_field:] + x1[:, receptive_field:] - mixed)
        # x0_update -= eta * gamma * (x0 + x1 - mixed)
        # x1_update -= eta * gamma * (x0 + x1 - mixed)

        with torch.no_grad():
            x0[:, receptive_field:] += x0_update
            x1[:, receptive_field:] += x1_update

            # x0 += x0_update
            # x1 += x1_update

        if not i % 50:  # debugging
            print("--------------")
            print('sigma = {}'.format(sigma))
            print('eta = {}'.format(eta))
            print("i {}".format(i))
            print("Max sample {}".format(abs(x0).max()))
            print('Mean sample logpx: {}'.format(log_prob0 / SAMPLE_SIZE))
            print('Mean sample logpy: {}'.format(log_prob1 / SAMPLE_SIZE))
            print("Max gradient update: {}".format(eta * abs(grad0).max()))
            # print("Reconstruction: {}".format(abs(x0 + x1 - mixed).mean()))

        # Reduce sigma
        sigma *= ratio

    # out0 = P.inv_mulaw_quantize(x0[0].detach().cpu().numpy(), hparams.quantize_channels - 1)
    out0 = inv_linear_quantize(x0[0].detach().cpu().numpy(),
                               hparams.quantize_channels - 1)
    out0 = np.clip(out0, -1, 1)
    sf.write("out0.wav", out0, hparams.sample_rate)

    out1 = inv_linear_quantize(x1[0].detach().cpu().numpy(),
                               hparams.quantize_channels - 1)
    out1 = np.clip(out1, -1, 1)
    sf.write("out1.wav", out1, hparams.sample_rate)
    # The trailing pdb.set_trace() breakpoint was removed: it blocked every
    # non-interactive run after the outputs had been written.
def main(args):
    """Sample a waveform from one conditional model by annealed Langevin dynamics.

    The first batch of the data loader provides the conditioning features
    ``c`` and a reference signal, which is written to ``x_original.wav``. A
    tensor of uniform noise, one sample longer than the reference, is then
    refined over a decreasing list of noise levels. For every level the
    matching checkpoint is loaded and 200 Langevin steps are taken, after
    which the current estimate (without its first sample) is written to
    ``out_<sigma>.wav``.

    Args:
        args: Mapping with the keys ``"<dump-root>"`` (data directory) and
            ``"<checkpoint>"`` (checkpoint root directory).
    """
    model = build_model().to(device)
    model.eval()

    test_data_loader = get_data_loader(args["<dump-root>"], collate_fn)

    (x, y, c, g, input_lengths) = next(iter(test_data_loader))
    # cin_pad = hparams.cin_pad
    # if cin_pad > 0:
    #     c = F.pad(c, pad=(cin_pad, cin_pad), mode="replicate")
    c = c.to(device)
    sanity_check(model, c, g)

    # Write inputs. Only the mu-law decoding is kept: the earlier
    # inv_linear_quantize result was overwritten before it was ever used.
    x_original_out = P.inv_mulaw_quantize(x, hparams.quantize_channels - 1)
    sf.write("x_original.wav", x_original_out[0, 0,], hparams.sample_rate)

    # Initialize with noise; one sample longer than the reference signal
    x = torch.FloatTensor(np.random.uniform(-512, 700, size=(1, x.shape[-1] + 1))).to(device)
    x.requires_grad = True

    # Noise levels, visited from largest to smallest.
    sigmas = [175.9, 110., 68.7,  42.9, 26.8, 16.8, 10.5, 6.55, 4.1, 2.56, 1.6, 1.0, 0.625, 0.39, 0.1]
    n_steps = 200

    for sigma in sigmas:
        # Bump down a model
        checkpoint_path = join(args["<checkpoint>"], checkpoints[sigma], "checkpoint_latest.pth")
        print("Load checkpoint0 from {}".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["state_dict"])

        # Step size for this noise level
        eta = .02 * (sigma ** 2)

        for i in range(n_steps):
            # Forward pass
            log_prob, _ = model.smoothed_loss(x, c=c, sigma=sigma)
            log_prob = torch.sum(log_prob)
            grad = torch.autograd.grad(log_prob, x)[0]
            x_update = eta * grad

            # Langevin step
            epsilon = np.sqrt(2 * eta) * torch.normal(0, 1, size=(1, x.shape[-1]), device=device)
            x_update += epsilon

            # x is a leaf that requires grad, so update it without recording
            with torch.no_grad():
                x += x_update

            if (not i % 20) or (i == (n_steps - 1)): # debugging
                print("--------------")
                print('sigma = {}'.format(sigma))
                print('eta = {}'.format(eta))
                print("i {}".format(i))
                print("Max sample {}".format(
                    abs(x).max()))
                print('Mean sample logpx: {}'.format(log_prob / x.shape[-1]))
                print("Max gradient update: {}".format(eta * abs(grad).max()))

        # Drop the extra leading sample before decoding
        out = P.inv_mulaw_quantize(x[0, 1:].detach().cpu().numpy(), hparams.quantize_channels - 1)
        # out = inv_linear_quantize(x[0].detach().cpu().numpy(), hparams.quantize_channels - 1)
        out = np.clip(out, -1, 1)
        sf.write("out_{}.wav".format(sigma), out, hparams.sample_rate)
Пример #9
0
def main(args):
    """Separate a two-source mixture with windowed annealed Langevin dynamics.

    Loads two ground-truth waveforms from hard-coded ``.npy`` files, sums
    them into a mixture and then, for every noise level in ``sigmas``, loads
    the matching checkpoint into each source model and runs stochastic
    gradient Langevin updates on ``BATCH_SIZE`` random windows of length
    ``SGLD_WINDOW`` of the two estimates ``x0`` and ``x1``. Each update adds
    the scaled model log-probability gradient, Gaussian noise, and a penalty
    that pulls ``x0 + x1`` towards the mixture.

    Side effects: writes ``mixed.wav``, ``x0_original.wav`` and
    ``x1_original.wav`` once, and ``out0_<sigma>.wav`` / ``out1_<sigma>.wav``
    after each noise level, all in the current working directory. Progress is
    printed every 20 steps.

    Args:
        args: Mapping with the checkpoint root directories stored under the
            keys ``"<checkpoint0>"`` and ``"<checkpoint1>"``.
    """
    model0 = ModelWrapper()
    model1 = ModelWrapper()

    receptive_field = model0.receptive_field

    # Load the ground-truth sources and crop both to SAMPLE_SIZE samples.
    x0_original = np.load("zf882fv0052-wave.npy")
    x0_original = x0_original[300000:300000 + SAMPLE_SIZE]

    # Alternative second source that was used previously:
    # "p341_048-wave.npy", samples 32000:32000 + SAMPLE_SIZE.
    x1_original = np.load("gettysburg10-wave.npy")
    x1_original = x1_original[:SAMPLE_SIZE]

    # The observed mixture is the sample-wise sum of both sources.
    mixed = torch.FloatTensor(x0_original + x1_original).reshape(1,
                                                                 -1).to(device)

    # Write inputs. The extra "- 1.0" presumably removes the offset that
    # summing two quantized signals introduces -- confirm against
    # inv_linear_quantize.
    mixed_out = inv_linear_quantize(mixed[0].detach().cpu().numpy(),
                                    hparams.quantize_channels - 1) - 1.0
    mixed_out = np.clip(mixed_out, -1, 1)
    sf.write("mixed.wav", mixed_out, hparams.sample_rate)

    x0_original_out = inv_linear_quantize(x0_original,
                                          hparams.quantize_channels - 1)
    sf.write("x0_original.wav", x0_original_out, hparams.sample_rate)

    x1_original_out = inv_linear_quantize(x1_original,
                                          hparams.quantize_channels - 1)
    sf.write("x1_original.wav", x1_original_out, hparams.sample_rate)

    # Initialize both estimates with uniform noise, padded on both sides by
    # one receptive field of the constant 127.
    x0 = torch.FloatTensor(np.random.uniform(-512, 700,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x0 = F.pad(x0, (receptive_field, receptive_field), "constant", 127)
    x0.requires_grad = True

    x1 = torch.FloatTensor(np.random.uniform(-512, 700,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x1 = F.pad(x1, (receptive_field, receptive_field), "constant", 127)
    x1.requires_grad = True

    # Noise levels, visited from coarse to fine; each one selects its own
    # checkpoint through the module-level `checkpoints` mapping.
    sigmas = [
        175.9, 110., 68.7, 54.3, 42.9, 34.0, 26.8, 21.2, 16.8, 10.5, 6.55, 4.1,
        2.56, 1.6, 1.0, 0.1
    ]

    np.random.seed(999)

    # The step count does not depend on sigma, so compute it once.
    n_steps = int((SAMPLE_SIZE / (SGLD_WINDOW * BATCH_SIZE)) * 60)

    for sigma in sigmas:
        # Bump down a model
        checkpoint_path0 = join(args["<checkpoint0>"], checkpoints[sigma],
                                "checkpoint_latest.pth")
        model0.load_checkpoint(checkpoint_path0)
        checkpoint_path1 = join(args["<checkpoint1>"], checkpoints[sigma],
                                "checkpoint_latest.pth")
        model1.load_checkpoint(checkpoint_path1)

        parmodel0 = torch.nn.DataParallel(model0)
        parmodel0.to(device)
        parmodel1 = torch.nn.DataParallel(model1)
        parmodel1.to(device)

        # Step size and mixture-constraint weight for this noise level.
        eta = .05 * (sigma**2)
        gamma = 15 * (1.0 / sigma)**2

        t0 = time.time()
        for i in range(n_steps):
            # need to get a good sampling of the beginning/end (boundary effects)
            # to understand this: think about how often we would update x[receptive_field] (first point)
            # if we only sampled U(receptive_field,x0.shape-receptive_field-SGLD_WINDOW)
            j = np.random.randint(receptive_field - SGLD_WINDOW,
                                  x0.shape[1] - receptive_field, BATCH_SIZE)
            j = np.maximum(j, receptive_field)
            j = np.minimum(j, x0.shape[1] - (SGLD_WINDOW + receptive_field))

            # Seed the padding with noised up silence. (Seeding with noised
            # ground truth instead is an option for unconditional generation.)
            # x0 and x1 are leaf tensors with requires_grad=True, so the
            # in-place overwrite must not be recorded by autograd -- same as
            # the parameter update at the end of the step.
            with torch.no_grad():
                x0[0, :receptive_field] = torch.FloatTensor(
                    np.random.normal(
                        127, sigma,
                        x0_original[:receptive_field].shape)).to(device)
                x0[0, -receptive_field:] = torch.FloatTensor(
                    np.random.normal(
                        127, sigma,
                        x0_original[-receptive_field:].shape)).to(device)
                x1[0, :receptive_field] = torch.FloatTensor(
                    np.random.normal(
                        127, sigma,
                        x1_original[:receptive_field].shape)).to(device)
                x1[0, -receptive_field:] = torch.FloatTensor(
                    np.random.normal(
                        127, sigma,
                        x1_original[-receptive_field:].shape)).to(device)

            # Cut one window per batch entry. The source patches carry one
            # receptive field of context on each side; `mixed` has no
            # padding, hence the shifted index.
            patches0 = []
            patches1 = []
            mixpatch = []
            for start in j:
                left = start - receptive_field
                right = start + SGLD_WINDOW + receptive_field
                patches0.append(x0[:, left:right])
                patches1.append(x1[:, left:right])
                mixpatch.append(mixed[:, left:left + SGLD_WINDOW])

            patches0 = torch.stack(patches0, dim=0)
            patches1 = torch.stack(patches1, dim=0)
            mixpatch = torch.stack(mixpatch, dim=0)

            # Forward pass
            log_prob, _ = parmodel0(patches0, sigma=sigma)
            log_prob0 = torch.sum(log_prob)
            grad0 = torch.autograd.grad(log_prob0, x0)[0]

            log_prob, _ = parmodel1(patches1, sigma=sigma)
            log_prob1 = torch.sum(log_prob)
            grad1 = torch.autograd.grad(log_prob1, x1)[0]

            # Gradient part of the update, restricted to each window.
            x0_update = [eta * grad0[:, s:s + SGLD_WINDOW] for s in j]
            x1_update = [eta * grad1[:, s:s + SGLD_WINDOW] for s in j]

            # Langevin step
            for k in range(BATCH_SIZE):
                epsilon0 = np.sqrt(2 * eta) * torch.normal(
                    0, 1, size=(1, SGLD_WINDOW), device=device)
                x0_update[k] += epsilon0

                epsilon1 = np.sqrt(2 * eta) * torch.normal(
                    0, 1, size=(1, SGLD_WINDOW), device=device)
                x1_update[k] += epsilon1

            # Reconstruction step: both sources receive the same penalty on
            # the mismatch between their sum and the mixture.
            core = slice(receptive_field, receptive_field + SGLD_WINDOW)
            for k in range(BATCH_SIZE):
                penalty = eta * gamma * (patches0[k][:, core] +
                                         patches1[k][:, core] - mixpatch[k])
                x0_update[k] -= penalty
                x1_update[k] -= penalty

            with torch.no_grad():
                for k in range(BATCH_SIZE):
                    x0[:, j[k]:j[k] + SGLD_WINDOW] += x0_update[k]
                    x1[:, j[k]:j[k] + SGLD_WINDOW] += x1_update[k]

            if (not i % 20) or (i == (n_steps - 1)):  # debugging
                print("--------------")
                print('sigma = {}'.format(sigma))
                print('eta = {}'.format(eta))
                print("i {}".format(i))
                print("Max sample {}".format(abs(x0).max()))
                print('Mean sample logpx: {}'.format(
                    log_prob0 / (BATCH_SIZE * SGLD_WINDOW)))
                print('Mean sample logpy: {}'.format(
                    log_prob1 / (BATCH_SIZE * SGLD_WINDOW)))
                print("Max gradient update: {}".format(eta * abs(grad0).max()))
                print("Reconstruction: {}".format(
                    abs(x0[:, receptive_field:-receptive_field] +
                        x1[:, receptive_field:-receptive_field] -
                        mixed).mean()))
                print('Elapsed time = {}'.format(time.time() - t0))
                t0 = time.time()

        # Dump the current estimates (padding included) for this noise level.
        out0 = inv_linear_quantize(x0[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out0 = np.clip(out0, -1, 1)
        sf.write("out0_{}.wav".format(sigma), out0, hparams.sample_rate)

        out1 = inv_linear_quantize(x1[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out1 = np.clip(out1, -1, 1)
        sf.write("out1_{}.wav".format(sigma), out1, hparams.sample_rate)
def main(args):
    """Separate a two-source mixture with full-signal annealed Langevin dynamics.

    Loads two ground-truth waveforms from hard-coded ``.npy`` files, sums
    them into a mixture and then, for every noise level in ``sigmas``, loads
    the matching checkpoint into each source model and runs 60 Langevin
    updates over the whole of the two estimates ``x0`` and ``x1``. Each
    update adds the scaled model log-probability gradient, Gaussian noise,
    and a penalty that pulls ``x0 + x1`` towards the mixture.

    Side effects: writes ``mixed.wav``, ``x0_original.wav`` and
    ``x1_original.wav`` once, and ``out0_<sigma>.wav`` / ``out1_<sigma>.wav``
    after each noise level, all in the current working directory. Progress is
    printed every 20 steps.

    Args:
        args: Mapping with the checkpoint root directories stored under the
            keys ``"<checkpoint0>"`` and ``"<checkpoint1>"``.
    """
    model0 = build_model().to(device)
    model0.eval()

    model1 = build_model().to(device)
    model1.eval()
    receptive_field = model0.receptive_field

    # Load the ground-truth sources and crop both to SAMPLE_SIZE samples.
    x0_original = np.load("zf882fv0052-wave.npy")
    x0_original = x0_original[300000:300000 + SAMPLE_SIZE]

    x1_original = np.load("p341_048-wave.npy")
    x1_original = x1_original[32000:32000 + SAMPLE_SIZE]

    # The observed mixture is the sample-wise sum of both sources.
    mixed = torch.FloatTensor(x0_original + x1_original).reshape(1,
                                                                 -1).to(device)

    # Write inputs. The extra "- 1.0" presumably removes the offset that
    # summing two quantized signals introduces -- confirm against
    # inv_linear_quantize.
    mixed_out = inv_linear_quantize(mixed[0].detach().cpu().numpy(),
                                    hparams.quantize_channels - 1) - 1.0
    mixed_out = np.clip(mixed_out, -1, 1)
    sf.write("mixed.wav", mixed_out, hparams.sample_rate)

    x0_original_out = inv_linear_quantize(x0_original,
                                          hparams.quantize_channels - 1)
    sf.write("x0_original.wav", x0_original_out, hparams.sample_rate)

    x1_original_out = inv_linear_quantize(x1_original,
                                          hparams.quantize_channels - 1)
    sf.write("x1_original.wav", x1_original_out, hparams.sample_rate)

    # Initialize both estimates with uniform noise, left-padded by one
    # receptive field of the constant 127.
    x0 = torch.FloatTensor(np.random.uniform(-512, 700,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x0 = F.pad(x0, (receptive_field, 0), "constant", 127)
    x0.requires_grad = True

    x1 = torch.FloatTensor(np.random.uniform(-512, 700,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x1 = F.pad(x1, (receptive_field, 0), "constant", 127)
    x1.requires_grad = True

    # Noise levels, visited from coarse to fine; each one selects its own
    # checkpoint through the module-level `checkpoints` mapping.
    sigmas = [
        175.9, 110., 68.7, 54.3, 42.9, 34.0, 26.8, 21.2, 16.8, 10.5, 6.55, 4.1,
        2.56, 1.6, 1.0, 0.1
    ]

    n_steps = 60

    for sigma in sigmas:
        # Bump down a model
        checkpoint_path0 = join(args["<checkpoint0>"], checkpoints[sigma],
                                "checkpoint_latest.pth")
        checkpoint_path1 = join(args["<checkpoint1>"], checkpoints[sigma],
                                "checkpoint_latest.pth")
        print("Load checkpoint0 from {}".format(checkpoint_path0))
        # map_location lets a checkpoint saved on another device be loaded
        # onto the device this script runs on.
        checkpoint0 = torch.load(checkpoint_path0, map_location=device)
        checkpoint1 = torch.load(checkpoint_path1, map_location=device)
        model0.load_state_dict(checkpoint0["state_dict"])
        model1.load_state_dict(checkpoint1["state_dict"])

        # Step size and mixture-constraint weight for this noise level.
        eta = .05 * (sigma**2)
        gamma = 15 * (1.0 / sigma)**2

        for i in range(n_steps):
            # Seed the padding with noised up silence. (Seeding with noised
            # ground truth instead is an option for unconditional generation.)
            # x0 and x1 are leaf tensors with requires_grad=True, so the
            # in-place overwrite must not be recorded by autograd -- same as
            # the parameter update at the end of the step.
            with torch.no_grad():
                x0[0, :receptive_field] = torch.FloatTensor(
                    np.random.normal(
                        127, sigma,
                        x0_original[:receptive_field].shape)).to(device)
                x1[0, :receptive_field] = torch.FloatTensor(
                    np.random.normal(
                        127, sigma,
                        x1_original[:receptive_field].shape)).to(device)

            # Forward pass; only log-probabilities from index
            # receptive_field - 1 onwards enter the objective.
            log_prob, _ = model0.smoothed_loss(x0, sigma=sigma)
            log_prob0 = torch.sum(log_prob[:, (receptive_field - 1):])
            grad0 = torch.autograd.grad(log_prob0, x0)[0]
            x0_update = eta * grad0[:, receptive_field:]

            log_prob, _ = model1.smoothed_loss(x1, sigma=sigma)
            log_prob1 = torch.sum(log_prob[:, (receptive_field - 1):])
            grad1 = torch.autograd.grad(log_prob1, x1)[0]
            x1_update = eta * grad1[:, receptive_field:]

            # Langevin step
            epsilon0 = np.sqrt(2 * eta) * torch.normal(
                0, 1, size=(1, SAMPLE_SIZE), device=device)
            x0_update += epsilon0

            epsilon1 = np.sqrt(2 * eta) * torch.normal(
                0, 1, size=(1, SAMPLE_SIZE), device=device)
            x1_update += epsilon1

            # Reconstruction step: both sources receive the same penalty on
            # the mismatch between their sum and the mixture.
            penalty = eta * gamma * (x0[:, receptive_field:] +
                                     x1[:, receptive_field:] - mixed)
            x0_update -= penalty
            x1_update -= penalty

            with torch.no_grad():
                x0[:, receptive_field:] += x0_update
                x1[:, receptive_field:] += x1_update

            if (not i % 20) or (i == (n_steps - 1)):  # debugging
                print("--------------")
                print('sigma = {}'.format(sigma))
                print('eta = {}'.format(eta))
                print("i {}".format(i))
                print("Max sample {}".format(abs(x0).max()))
                print('Mean sample logpx: {}'.format(log_prob0 / SAMPLE_SIZE))
                print('Mean sample logpy: {}'.format(log_prob1 / SAMPLE_SIZE))
                print("Max gradient update: {}".format(eta * abs(grad0).max()))
                print("Reconstruction: {}".format(
                    abs(x0[:, receptive_field:] + x1[:, receptive_field:] -
                        mixed).mean()))

        # Dump the current estimates (padding included) for this noise level.
        out0 = inv_linear_quantize(x0[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out0 = np.clip(out0, -1, 1)
        sf.write("out0_{}.wav".format(sigma), out0, hparams.sample_rate)

        out1 = inv_linear_quantize(x1[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out1 = np.clip(out1, -1, 1)
        sf.write("out1_{}.wav".format(sigma), out1, hparams.sample_rate)