def get_voice_file(idx, duration, quantize_type):
    """
    Gets one of the last VCTK voices
    """
    BASE_PATH = "/projects/grail/audiovisual/datasets/VCTK-Corpus/wav48/test"

    assert 0 <= idx < 100
    if idx < 25:
        speaker_path = os.path.join(BASE_PATH, "p345")
    elif idx < 50:
        speaker_path = os.path.join(BASE_PATH, "p361")
    elif idx < 75:
        speaker_path = os.path.join(BASE_PATH, "p362")
    elif idx < 100:
        speaker_path = os.path.join(BASE_PATH, "p374")

    file_list = list(Path(speaker_path).rglob('*.wav'))

    curr_file = random.choice(file_list)
    y, sr = librosa.core.load(curr_file, sr=22050)

    y /= abs(y).max()
    start_idx = len(y) // 2
    y = y[int(start_idx - duration / 2):int(start_idx + duration / 2)]

    # Quantization scheme: 0 = mu-law, 1 = linear
    if quantize_type == 0:
        quantized = P.mulaw_quantize(y, hparams.quantize_channels - 1)
    elif quantize_type == 1:
        quantized = linear_quantize(y, hparams.quantize_channels - 1)
    else:
        raise ValueError("Unsupported quantize_type: {}".format(quantize_type))

    return quantized


def get_piano_file(idx, duration, quantize_type):
    """
    Gets one of the test supra piano samples
    """
    BASE_PATH = "/projects/grail/audiovisual/datasets/supra-rw-mp3/test"

    file_list = list(Path(BASE_PATH).rglob("*.mp3"))
    curr_file = random.choice(file_list)
    y, sr = librosa.core.load(curr_file, sr=22050)
    y /= abs(y).max()

    num_samples = y.shape[0]
    start_idx = random.randint(0, num_samples - duration)
    y = y[start_idx:start_idx + duration]

    # Quantization scheme: 0 = mu-law, 1 = linear
    if quantize_type == 0:
        quantized = P.mulaw_quantize(y, hparams.quantize_channels - 1)
    elif quantize_type == 1:
        quantized = linear_quantize(y, hparams.quantize_channels - 1)
    else:
        raise ValueError("Unsupported quantize_type: {}".format(quantize_type))

    return quantized
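For orientation, here is a hedged, illustrative usage of the two loaders above; the variable names and the one-second duration are made up for this sketch, and quantize_type follows the convention 0 = mu-law, 1 = linear.

# Illustrative only: pull a one-second voice clip and a one-second piano clip.
duration = 22050                                      # one second at the 22050 Hz load rate used above
voice_q = get_voice_file(idx=7, duration=duration, quantize_type=0)
piano_q = get_piano_file(idx=7, duration=duration, quantize_type=0)
print(voice_q.shape, voice_q.min(), voice_q.max())    # integer codes in [0, quantize_channels - 1]
print(piano_q.shape, piano_q.min(), piano_q.max())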
Example #3
def main(args):
    model0 = ModelWrapper()
    model1 = ModelWrapper()

    receptive_field = model0.receptive_field

    writing_dir = args["<output-dir>"]
    os.makedirs(writing_dir, exist_ok=True)
    print("writing dir: {}".format(writing_dir))

    source1 = librosa.core.load(args["<input-file1>"], sr=22050, mono=True)[0]
    source2 = librosa.core.load(args["<input-file2>"], sr=22050, mono=True)[0]
    mixed = source1 + source2

    # Increase the volume of the mixture to avoid artifacts from linear encoding
    mixed /= abs(mixed).max()
    mixed *= 1.4
    mixed = linear_quantize(mixed + 1.0, hparams.quantize_channels - 1)
    global SAMPLE_SIZE
    if SAMPLE_SIZE == -1:
        SAMPLE_SIZE = int(mixed.shape[0])

    mixed = mixed[:SAMPLE_SIZE]

    mixed = torch.FloatTensor(mixed).reshape(1, -1).to(device)

    # Write inputs
    mixed_out = inv_linear_quantize(mixed[0].detach().cpu().numpy(),
                                    hparams.quantize_channels - 1) - 1.0
    mixed_out = np.clip(mixed_out, -1, 1)
    sf.write(join(writing_dir, "mixed.wav"), mixed_out, hparams.sample_rate)

    # Initialize the two source estimates: x0 from the centered mixture, x1 at the silence level (127)
    x0 = torch.FloatTensor(np.random.uniform(0, 512,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x0[:] = mixed - 127.0
    x0 = F.pad(x0, (receptive_field, receptive_field), "constant", 127)
    x0.requires_grad = True

    x1 = torch.FloatTensor(np.random.uniform(0, 512,
                                             size=(1, SAMPLE_SIZE))).to(device)
    x1[:] = 127.
    x1 = F.pad(x1, (receptive_field, receptive_field), "constant", 127)
    x1.requires_grad = True

    sigmas = [
        175.9, 110., 68.7, 54.3, 42.9, 34.0, 26.8, 21.2, 16.8, 13.3, 10.5,
        8.29, 6.55, 5.18, 4.1, 3.24, 2.56, 1.6, 1.0, 0.625, 0.39, 0.244, 0.15,
        0.1
    ]

    np.random.seed(999)

    for idx, sigma in enumerate(sigmas):
        # Scale the step count so that, on average, every sample is updated N_STEPS times
        n_steps = int((SAMPLE_SIZE / (SGLD_WINDOW * BATCH_SIZE)) * N_STEPS)
        print("Number of SGLD steps {}".format(n_steps))
        # Load the checkpoints trained at this noise level (sigma)
        checkpoint_path0 = join(args["<checkpoint0>"], CHECKPOINTS[sigma],
                                "checkpoint_latest_ema.pth")
        model0.load_checkpoint(checkpoint_path0)
        checkpoint_path1 = join(args["<checkpoint1>"], CHECKPOINTS[sigma],
                                "checkpoint_latest_ema.pth")
        model1.load_checkpoint(checkpoint_path1)

        parmodel0 = torch.nn.DataParallel(model0)
        parmodel0.to(device)
        parmodel1 = torch.nn.DataParallel(model1)
        parmodel1.to(device)

        eta = .05 * (sigma**2)
        gamma = 15 * (1.0 / sigma)**2

        t0 = time.time()
        for i in range(n_steps):
            # Sample window start positions so that the boundaries also get enough updates.
            # If we only sampled U(receptive_field, x0.shape[1] - receptive_field - SGLD_WINDOW),
            # the first point x[receptive_field] would almost never be updated, so we sample a
            # wider range and clamp to valid positions below.
            j = np.random.randint(receptive_field - SGLD_WINDOW,
                                  x0.shape[1] - receptive_field, BATCH_SIZE)
            j = np.maximum(j, receptive_field)
            j = np.minimum(j, x0.shape[1] - (SGLD_WINDOW + receptive_field))

            # Seed with noised up silence
            x0[0, :receptive_field] = torch.FloatTensor(
                np.random.normal(127, sigma,
                                 mixed[0, :receptive_field].shape)).to(device)
            x0[0, -receptive_field:] = torch.FloatTensor(
                np.random.normal(127, sigma,
                                 mixed[0, -receptive_field:].shape)).to(device)
            x1[0, :receptive_field] = torch.FloatTensor(
                np.random.normal(127, sigma,
                                 mixed[0, :receptive_field].shape)).to(device)
            x1[0, -receptive_field:] = torch.FloatTensor(
                np.random.normal(127, sigma,
                                 mixed[0, -receptive_field:].shape)).to(device)

            patches0 = []
            patches1 = []
            mixpatch = []
            for k in range(BATCH_SIZE):
                patches0.append(x0[:, j[k] - receptive_field:j[k] +
                                   SGLD_WINDOW + receptive_field])
                patches1.append(x1[:, j[k] - receptive_field:j[k] +
                                   SGLD_WINDOW + receptive_field])
                mixpatch.append(mixed[:, j[k] - receptive_field:j[k] -
                                      receptive_field + SGLD_WINDOW])

            patches0 = torch.stack(patches0, dim=0)
            patches1 = torch.stack(patches1, dim=0)
            mixpatch = torch.stack(mixpatch, dim=0)

            # Forward pass
            log_prob, prediction0 = parmodel0(patches0, sigma=sigma)
            log_prob0 = torch.sum(log_prob)
            grad0 = torch.autograd.grad(log_prob0, x0)[0]

            log_prob, prediction1 = parmodel1(patches1, sigma=sigma)
            log_prob1 = torch.sum(log_prob)
            grad1 = torch.autograd.grad(log_prob1, x1)[0]

            x0_update, x1_update = [], []
            for k in range(BATCH_SIZE):
                x0_update.append(eta * grad0[:, j[k]:j[k] + SGLD_WINDOW])
                x1_update.append(eta * grad1[:, j[k]:j[k] + SGLD_WINDOW])

            # Langevin step
            for k in range(BATCH_SIZE):
                epsilon0 = np.sqrt(2 * eta) * torch.normal(
                    0, 1, size=(1, SGLD_WINDOW), device=device)
                x0_update[k] += epsilon0

                epsilon1 = np.sqrt(2 * eta) * torch.normal(
                    0, 1, size=(1, SGLD_WINDOW), device=device)
                x1_update[k] += epsilon1

            # Reconstruction step
            for k in range(BATCH_SIZE):
                x0_update[k] -= eta * gamma * (
                    patches0[k][:, receptive_field:receptive_field +
                                SGLD_WINDOW] +
                    patches1[k][:, receptive_field:receptive_field +
                                SGLD_WINDOW] - mixpatch[k])
                x1_update[k] -= eta * gamma * (
                    patches0[k][:, receptive_field:receptive_field +
                                SGLD_WINDOW] +
                    patches1[k][:, receptive_field:receptive_field +
                                SGLD_WINDOW] - mixpatch[k])

            with torch.no_grad():
                for k in range(BATCH_SIZE):
                    x0[:, j[k]:j[k] + SGLD_WINDOW] += x0_update[k]
                    x1[:, j[k]:j[k] + SGLD_WINDOW] += x1_update[k]

            if (not i % 40) or (i == (n_steps - 1)):  # debugging
                print("--------------")
                print('sigma = {}'.format(sigma))
                print('eta = {}'.format(eta))
                print("i {}".format(i))
                print("Max sample {}".format(abs(x0).max()))
                print('Mean sample logpx: {}'.format(
                    log_prob0 / (BATCH_SIZE * SGLD_WINDOW)))
                print('Mean sample logpy: {}'.format(
                    log_prob1 / (BATCH_SIZE * SGLD_WINDOW)))
                print("Max gradient update: {}".format(eta * abs(grad0).max()))
                print("Reconstruction: {}".format(
                    abs(x0[:, receptive_field:-receptive_field] +
                        x1[:, receptive_field:-receptive_field] -
                        mixed).mean()))
                print('Elapsed time = {}'.format(time.time() - t0))
                t0 = time.time()

        out0 = inv_linear_quantize(x0[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out0 = np.clip(out0, -1, 1)
        sf.write(join(writing_dir, "out0_{}.wav".format(sigma)), out0,
                 hparams.sample_rate)

        out1 = inv_linear_quantize(x1[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out1 = np.clip(out1, -1, 1)
        sf.write(join(writing_dir, "out1_{}.wav".format(sigma)), out1,
                 hparams.sample_rate)
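The inner loop above is a noise-annealed Langevin (SGLD) update with an extra reconstruction term: each source estimate is nudged along its model's score, perturbed by Gaussian noise of variance 2 * eta, and pulled toward satisfying x0 + x1 = mixture. Below is a minimal NumPy sketch of that per-step update using the same eta/gamma schedule as main(); the function name and stand-in gradients are illustrative, not part of the repository.

import numpy as np

def sgld_separation_step(x0, x1, mixture, grad0, grad1, sigma):
    """One annealed-Langevin update for two source estimates (sketch).

    grad0 / grad1 stand in for the model scores d/dx log p(x | sigma)
    computed by parmodel0 / parmodel1 above.
    """
    eta = 0.05 * sigma ** 2              # step size, as in main()
    gamma = 15.0 / sigma ** 2            # reconstruction weight, as in main()

    noise0 = np.sqrt(2 * eta) * np.random.randn(*x0.shape)
    noise1 = np.sqrt(2 * eta) * np.random.randn(*x1.shape)
    residual = x0 + x1 - mixture         # how far the estimates are from summing to the mixture

    x0 = x0 + eta * grad0 + noise0 - eta * gamma * residual
    x1 = x1 + eta * grad1 + noise1 - eta * gamma * residual
    return x0, x1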
Example #4
def _process_utterance(out_dir, index, wav_path, text, no_mel):
    # Load the audio to a numpy array:
    wav = audio.load_wav(wav_path)

    # Trim begin/end silences
    # NOTE: the threshold was chosen for clean signals
    wav, _ = librosa.effects.trim(wav,
                                  top_db=60,
                                  frame_length=2048,
                                  hop_length=512)

    if hparams.highpass_cutoff > 0.0:
        wav = audio.low_cut_filter(wav, hparams.sample_rate,
                                   hparams.highpass_cutoff)

    # Mu-law quantize
    if is_mulaw_quantize(hparams.input_type):
        # Trim silences in the mu-law quantized domain
        silence_threshold = 0
        if silence_threshold > 0:
            # [0, quantize_channels)
            out = P.mulaw_quantize(wav, hparams.quantize_channels - 1)
            start, end = audio.start_and_end_indices(out, silence_threshold)
            wav = wav[start:end]
        constant_values = P.mulaw_quantize(0, hparams.quantize_channels - 1)
        out_dtype = np.int16
    elif is_linear_quantize(hparams.input_type):
        # Trim silences in linear quantized domain
        silence_threshold = 0
        if silence_threshold > 0:
            # [0, quantize_channels)
            out = linear_quantize(wav, hparams.quantize_channels - 1)
            start, end = audio.start_and_end_indices(out, silence_threshold)
            wav = wav[start:end]
        constant_values = linear_quantize(0, hparams.quantize_channels - 1)
        out_dtype = np.int16
    elif is_mulaw(hparams.input_type):
        # [-1, 1]
        constant_values = P.mulaw(0.0, hparams.quantize_channels - 1)
        out_dtype = np.float32
    else:
        # [-1, 1]
        constant_values = 0.0
        out_dtype = np.float32

    if hparams.global_gain_scale > 0:
        wav *= hparams.global_gain_scale

    if hparams.normalize_max_audio:
        wav /= abs(wav).max()

    # Compute a mel-scale spectrogram from the trimmed wav:
    # (N, D)
    if not no_mel:
        mel_spectrogram = audio.logmelspectrogram(wav).astype(np.float32).T

    # Time domain preprocessing
    if hparams.preprocess is not None and hparams.preprocess not in [
            "", "none"
    ]:
        f = getattr(audio, hparams.preprocess)
        wav = f(wav)

    # Clip
    if np.abs(wav).max() > 1.0:
        print("""Warning: abs max value exceeds 1.0: {}""".format(
            np.abs(wav).max()))
        # ignore this sample
        # return ("dummy", "dummy", -1, "dummy")

    wav = np.clip(wav, -1.0, 1.0)

    # Set waveform target (out)
    if is_mulaw_quantize(hparams.input_type):
        out = P.mulaw_quantize(wav, hparams.quantize_channels - 1)
    elif is_linear_quantize(hparams.input_type):
        out = linear_quantize(wav, hparams.quantize_channels - 1)
    elif is_mulaw(hparams.input_type):
        out = P.mulaw(wav, hparams.quantize_channels - 1)
    else:
        out = wav

    # zero pad
    # this is needed to adjust time resolution between audio and mel-spectrogram
    if not no_mel:
        l, r = audio.pad_lr(out, hparams.fft_size, audio.get_hop_size())
        if l > 0 or r > 0:
            out = np.pad(out, (l, r),
                         mode="constant",
                         constant_values=constant_values)
        N = mel_spectrogram.shape[0]
        assert len(out) >= N * audio.get_hop_size()

        # time resolution adjustment
        # ensure length of raw audio is multiple of hop_size so that we can use
        # transposed convolution to upsample
        out = out[:N * audio.get_hop_size()]
        assert len(out) % audio.get_hop_size() == 0

    # Write the spectrograms to disk:
    name = splitext(basename(wav_path))[0]
    audio_filename = '{}-{}-wave.npy'.format(name, index)
    np.save(os.path.join(out_dir, audio_filename),
            out.astype(out_dtype),
            allow_pickle=False)
    if not no_mel:
        mel_filename = '{}-{}-feats.npy'.format(name, index)
        np.save(os.path.join(out_dir, mel_filename),
                mel_spectrogram.astype(np.float32),
                allow_pickle=False)
    else:
        mel_filename = ""
        # No mel was computed; approximate the frame count from the audio length
        # so the returned tuple stays well-defined (assumption for the no_mel path).
        N = len(out) // audio.get_hop_size()

    # Return a tuple describing this training example:
    return (audio_filename, mel_filename, N, text)
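The padding and trimming above enforce a simple invariant: the saved waveform length is exactly the number of mel frames times the hop size, so the conditioning network can upsample features to samples with a transposed convolution. A tiny numeric sketch of that invariant follows; 80 frames and a 256-sample hop are assumed example values, not the project's actual settings.

n_frames, hop_size = 80, 256            # assumed example values
waveform_len = n_frames * hop_size      # 20480 samples
assert waveform_len % hop_size == 0     # the same invariant asserted in _process_utterance
# Each mel frame then lines up with exactly hop_size consecutive samples, which is what
# lets a transposed convolution with stride hop_size upsample (N, D) features to N * hop_size samples.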
Example #5
def eval_model(global_step,
               writer,
               device,
               model,
               y,
               c,
               g,
               input_lengths,
               eval_dir,
               ema=None):
    if ema is not None:
        print("Using averaged model for evaluation")
        model = clone_as_averaged_model(device, model, ema)
        model.make_generation_fast_()

    model.eval()
    idx = np.random.randint(0, len(y))
    length = input_lengths[idx].data.cpu().item()

    # (T,)
    y_target = y[idx].view(-1).data.cpu().numpy()[:length]

    if c is not None:
        if hparams.upsample_conditional_features:
            c = c[idx, :, :length // audio.get_hop_size() +
                  hparams.cin_pad * 2].unsqueeze(0)
        else:
            c = c[idx, :, :length].unsqueeze(0)
        assert c.dim() == 3
        print("Shape of local conditioning features: {}".format(c.size()))
    if g is not None:
        # TODO: test
        g = g[idx]
        print("Shape of global conditioning features: {}".format(g.size()))

    # Dummy silence
    if is_mulaw_quantize(hparams.input_type):
        initial_value = P.mulaw_quantize(0, hparams.quantize_channels - 1)
    elif is_linear_quantize(hparams.input_type):
        initial_value = linear_quantize(0, hparams.quantize_channels - 1)
    elif is_mulaw(hparams.input_type):
        initial_value = P.mulaw(0.0, hparams.quantize_channels)
    else:
        initial_value = 0.0

    # (C,)
    if (is_mulaw_quantize(hparams.input_type) or is_linear_quantize(
            hparams.input_type)) and not hparams.manual_scalar_input:
        initial_input = to_categorical(
            initial_value,
            num_classes=hparams.quantize_channels).astype(np.float32)
        initial_input = torch.from_numpy(initial_input).view(
            1, 1, hparams.quantize_channels)
    else:
        initial_input = torch.zeros(1, 1, 1).fill_(initial_value)
    initial_input = initial_input.to(device)

    # Run the model in fast eval mode
    with torch.no_grad():
        y_hat = model.incremental_forward(initial_input,
                                          c=c,
                                          g=g,
                                          T=length,
                                          softmax=True,
                                          quantize=True,
                                          tqdm=tqdm,
                                          log_scale_min=hparams.log_scale_min)

    if is_mulaw_quantize(hparams.input_type):
        y_hat = y_hat.max(1)[1].view(-1).long().cpu().data.numpy()
        y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1)
        y_target = P.inv_mulaw_quantize(y_target,
                                        hparams.quantize_channels - 1)
    elif is_linear_quantize(hparams.input_type):
        y_hat = y_hat.max(1)[1].view(-1).long().cpu().data.numpy()
        y_hat = inv_linear_quantize(y_hat, hparams.quantize_channels - 1)
        y_target = inv_linear_quantize(y_target, hparams.quantize_channels - 1)
    elif is_mulaw(hparams.input_type):
        y_hat = P.inv_mulaw(
            y_hat.view(-1).cpu().data.numpy(), hparams.quantize_channels)
        y_target = P.inv_mulaw(y_target, hparams.quantize_channels)
    else:
        y_hat = y_hat.view(-1).cpu().data.numpy()

    # Save audio
    os.makedirs(eval_dir, exist_ok=True)
    path = join(eval_dir, "step{:09d}_predicted.wav".format(global_step))
    librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
    path = join(eval_dir, "step{:09d}_target.wav".format(global_step))
    librosa.output.write_wav(path, y_target, sr=hparams.sample_rate)

    # save figure
    path = join(eval_dir, "step{:09d}_waveplots.png".format(global_step))
    save_waveplot(path, y_hat, y_target)
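eval_model maps the network's categorical output back to a waveform with the inverse of whatever companding input_type selects. Assuming P is nnmnkwii.preprocessing (consistent with the P.mulaw_quantize / P.inv_mulaw_quantize calls used throughout), the mu-law path is a straightforward round trip:

import numpy as np
from nnmnkwii import preprocessing as P

x = np.linspace(-1.0, 1.0, 5)              # a few samples in [-1, 1]
q = P.mulaw_quantize(x, 255)               # integer classes in [0, 255]
x_rec = P.inv_mulaw_quantize(q, 255)       # approximately x, up to quantization error
print(q, np.abs(x_rec - x).max())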
Example #6
def collate_fn(batch):
    """Create batch

    Args:
        batch(tuple): List of tuples
            - x[0] (ndarray,int) : list of (T,)
            - x[1] (ndarray,int) : list of (T, D)
            - x[2] (ndarray,int) : list of (1,), speaker id
    Returns:
        tuple: Tuple of batch
            - x (FloatTensor) : Network inputs (B, C, T)
            - y (LongTensor)  : Network targets (B, T, 1)
    """
    local_conditioning = len(batch[0]) >= 2 and hparams.cin_channels > 0
    global_conditioning = len(batch[0]) >= 3 and hparams.gin_channels > 0

    if hparams.max_time_sec is not None:
        max_time_steps = int(hparams.max_time_sec * hparams.sample_rate)
    elif hparams.max_time_steps is not None:
        max_time_steps = hparams.max_time_steps
    else:
        max_time_steps = None

    # Time resolution adjustment
    cin_pad = hparams.cin_pad
    if local_conditioning:
        new_batch = []
        for idx in range(len(batch)):
            x, c, g = batch[idx]
            if hparams.upsample_conditional_features:
                assert_ready_for_upsampling(x, c, cin_pad=0)
                if max_time_steps is not None:
                    max_steps = ensure_divisible(max_time_steps,
                                                 audio.get_hop_size(), True)
                    if len(x) > max_steps:
                        max_time_frames = max_steps // audio.get_hop_size()
                        s = np.random.randint(
                            cin_pad,
                            len(c) - max_time_frames - cin_pad)
                        ts = s * audio.get_hop_size()
                        x = x[ts:ts + audio.get_hop_size() * max_time_frames]
                        c = c[s - cin_pad:s + max_time_frames + cin_pad, :]
                        assert_ready_for_upsampling(x, c, cin_pad=cin_pad)
            else:
                x, c = audio.adjust_time_resolution(x, c)
                if max_time_steps is not None and len(x) > max_time_steps:
                    s = np.random.randint(cin_pad,
                                          len(x) - max_time_steps - cin_pad)
                    x = x[s:s + max_time_steps]
                    c = c[s - cin_pad:s + max_time_steps + cin_pad, :]
                assert len(x) == len(c)
            new_batch.append((x, c, g))
        batch = new_batch
    else:
        new_batch = []
        for idx in range(len(batch)):
            x, c, g = batch[idx]
            x = audio.trim(x)
            if max_time_steps is not None and len(x) > max_time_steps:
                s = np.random.randint(0, len(x) - max_time_steps)
                if local_conditioning:
                    x, c = x[s:s + max_time_steps], c[s:s + max_time_steps, :]
                else:
                    x = x[s:s + max_time_steps]
            new_batch.append((x, c, g))
        batch = new_batch

    # Lengths
    input_lengths = [len(x[0]) for x in batch]
    max_input_len = max(input_lengths)

    # (B, T, C)
    # pad for time-axis
    if is_mulaw_quantize(hparams.input_type):
        padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1)

        if hparams.manual_scalar_input:
            x_batch = np.array([
                _pad_2d(x[0].reshape(-1, 1), max_input_len, 0, padding_value)
                for x in batch
            ],
                               dtype=np.float32)
        else:
            x_batch = np.array([
                _pad_2d(
                    to_categorical(x[0],
                                   num_classes=hparams.quantize_channels),
                    max_input_len, 0, padding_value) for x in batch
            ],
                               dtype=np.float32)
    elif is_linear_quantize(hparams.input_type):
        padding_value = linear_quantize(0, hparams.quantize_channels - 1)

        if hparams.manual_scalar_input:
            x_batch = np.array([
                _pad_2d(x[0].reshape(-1, 1), max_input_len, 0, padding_value)
                for x in batch
            ],
                               dtype=np.float32)
        else:
            x_batch = np.array([
                _pad_2d(
                    to_categorical(x[0],
                                   num_classes=hparams.quantize_channels),
                    max_input_len, 0, padding_value) for x in batch
            ],
                               dtype=np.float32)

    else:
        x_batch = np.array(
            [_pad_2d(x[0].reshape(-1, 1), max_input_len) for x in batch],
            dtype=np.float32)
    assert len(x_batch.shape) == 3

    # (B, T)
    if is_mulaw_quantize(hparams.input_type):
        padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1)
        y_batch = np.array([
            _pad(x[0], max_input_len, constant_values=padding_value)
            for x in batch
        ],
                           dtype=np.int64)
    elif is_linear_quantize(hparams.input_type):
        padding_value = linear_quantize(0, hparams.quantize_channels - 1)
        y_batch = np.array([
            _pad(x[0], max_input_len, constant_values=padding_value)
            for x in batch
        ],
                           dtype=np.int64)
    else:
        y_batch = np.array([_pad(x[0], max_input_len) for x in batch],
                           dtype=np.float32)
    assert len(y_batch.shape) == 2

    # (B, T, D)
    if local_conditioning:
        max_len = max([len(x[1]) for x in batch])
        c_batch = np.array([_pad_2d(x[1], max_len) for x in batch],
                           dtype=np.float32)
        assert len(c_batch.shape) == 3
        # (B x C x T)
        c_batch = torch.FloatTensor(c_batch).transpose(1, 2).contiguous()
    else:
        c_batch = None

    if global_conditioning:
        g_batch = torch.LongTensor([x[2] for x in batch])
    else:
        g_batch = None

    # Convert to channel-first, i.e., (B, C, T)
    x_batch = torch.FloatTensor(x_batch).transpose(1, 2).contiguous()
    # Add extra axis
    if is_mulaw_quantize(hparams.input_type) or is_linear_quantize(
            hparams.input_type):
        y_batch = torch.LongTensor(y_batch).unsqueeze(-1).contiguous()
    else:
        y_batch = torch.FloatTensor(y_batch).unsqueeze(-1).contiguous()

    input_lengths = torch.LongTensor(input_lengths)

    return x_batch, y_batch, c_batch, g_batch, input_lengths
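collate_fn relies on two padding helpers, _pad and _pad_2d, that are not shown in this excerpt. Based on how they are called (a 1-D target padded on the right, a 2-D feature matrix padded along time with an optional left pad and a fill value), they likely look something like the sketch below; treat it as an assumption, not the project's exact code.

import numpy as np

def _pad(seq, max_len, constant_values=0):
    # Right-pad a 1-D sequence to max_len with a constant fill value.
    return np.pad(seq, (0, max_len - len(seq)),
                  mode="constant", constant_values=constant_values)

def _pad_2d(x, max_len, b_pad=0, constant_values=0):
    # Pad a (T, D) array along the time axis: b_pad frames on the left,
    # the remainder on the right, leaving the feature axis untouched.
    return np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)],
                  mode="constant", constant_values=constant_values)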