Example #1
0
def test_generate():
    """Smoke-test every generation entry point of ``WaveNet`` in sampling mode.

    A network initialised with ``initialize`` is run through ``generate`` and
    ``fast_generate`` one example at a time, then through
    ``batch_fast_generate`` on the whole batch. Only successful completion
    is checked; the generated outputs are not compared.
    """
    n_examples = 2
    seeds = np.random.randint(0, 256, size=(n_examples, 1))
    aux = np.random.randn(n_examples, 28, 32)
    # One step fewer than the number of aux frames (the single seed sample
    # presumably accounts for the first frame).
    n_steps = aux.shape[-1] - 1
    with torch.no_grad():
        model = WaveNet(256, 28, 16, 32, 10, 3, 2)
        model.apply(initialize)
        model.eval()
        for seed, cond in zip(seeds, aux):
            # Add a leading batch axis of size 1.
            seed_t = torch.from_numpy(seed[np.newaxis]).long()
            cond_t = torch.from_numpy(cond[np.newaxis]).float()
            model.generate(seed_t, cond_t, n_steps, 1, "sampling")
            model.fast_generate(seed_t, cond_t, n_steps, 1, "sampling")
        model.batch_fast_generate(
            torch.from_numpy(seeds).long(),
            torch.from_numpy(aux).float(),
            [n_steps] * n_examples,
            1,
            "sampling",
        )
Example #2
0
def _load_encoder(model_args, checkpoint):
    """Build the encoder, load ``encoder_state`` from ``checkpoint``, move it to CUDA."""
    encoder = wavenet_models.Encoder(model_args)
    encoder.load_state_dict(torch.load(checkpoint)['encoder_state'])
    encoder.eval()
    return encoder.cuda()


def _load_decoder(model_args, checkpoint, args):
    """Build a decoder from ``checkpoint`` and wrap it in a generator object.

    ``args.py`` selects ``WavenetGenerator``; otherwise ``NVWavenetGenerator``
    is used.
    """
    decoder = WaveNet(model_args)
    decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
    decoder.eval()
    decoder = decoder.cuda()
    if args.py:
        return WavenetGenerator(decoder,
                                args.batch_size,
                                wav_freq=args.rate)
    return NVWavenetGenerator(decoder,
                              args.rate * (args.split_size // 20),
                              args.batch_size, 3)


def _list_input_files(args):
    """Expand ``args.files`` into the list of audio files to process.

    A single directory argument is searched recursively for ``.wav`` and
    ``.h5`` files; otherwise ``args.files`` is used as given.
    """
    if len(args.files) == 1 and args.files[0].is_dir():
        top = args.files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = args.files

    if not args.skip_filter:
        # Skip names containing '_' -- presumably outputs of an earlier run,
        # which are written next to the originals as '<stem>_<decoder>.wav'.
        file_paths = [f for f in file_paths if '_' not in str(f.name)]
    return file_paths


def _read_mu_law_audio(file_path, rate):
    """Load a ``.wav`` or ``.h5`` file and return its mu-law encoded samples.

    ``.h5`` data is trimmed to a whole multiple of ``rate`` samples.

    Raises:
        ValueError: the file extension is not ``.wav`` or ``.h5``.
    """
    if file_path.suffix == '.wav':
        data, sr = librosa.load(file_path, sr=16000)
        assert sr == 16000
        return utils.mu_law(data)

    if file_path.suffix == '.h5':
        # Close the HDF5 file as soon as the dataset has been read.
        with h5py.File(file_path, 'r') as h5:
            # Dividing by 2**15 presumably maps int16 PCM to [-1, 1).
            raw = h5['wav'][:] / (2**15)
        data = utils.mu_law(raw)
        remainder = data.shape[-1] % rate
        if remainder:
            data = data[:-remainder]
        assert data.shape[-1] % rate == 0
        return data

    raise ValueError(f'Unsupported filetype {file_path}')


def main(args):
    """Encode input audio files and re-synthesize them with each selected decoder.

    The encoder is loaded from the first checkpoint named
    ``<args.checkpoint.name>_*.pth`` whose id is in ``args.decoders``. One
    decoder is loaded per such checkpoint. Every input file is encoded, then
    decoded by each decoder, and one ``.wav`` is written per
    (decoder, input file) pair.

    Raises:
        ValueError: not exactly one of ``args.output_next_to_orig`` and
            ``args.output`` is set, or an input file has an unsupported
            extension.
        FileNotFoundError: no matching checkpoint or no input file was found.
    """
    print('Starting')
    matplotlib.use('agg')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Validate the output mode before any expensive model loading.
    if bool(args.output_next_to_orig) == (args.output is not None):
        raise ValueError(
            'Exactly one of output_next_to_orig or output must be given.')

    checkpoints = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in args.decoders]
    if not checkpoints:
        raise FileNotFoundError("No checkpoints found.")

    model_args = torch.load(args.model.parent / 'args.pth')[0]
    encoder = _load_encoder(model_args, checkpoints[0])
    decoders = [_load_decoder(model_args, c, args) for c in checkpoints]
    decoder_ids = [extract_id(c) for c in checkpoints]

    file_paths = _list_input_files(args)
    if not file_paths:
        raise FileNotFoundError("No input files found.")

    xs = []
    for file_path in file_paths:
        data = _read_mu_law_audio(file_path, args.rate)
        # When sample_len is unset, the first file fixes it and later files
        # are cropped to that length so they can be stacked. A later file
        # that is shorter still makes torch.stack fail.
        if args.sample_len:
            data = data[:args.sample_len]
        else:
            args.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    # With several decoders sharing one output directory, the file name must
    # carry the decoder id or each decoder would overwrite the previous one.
    multiple_decoders = len(set(decoder_ids)) > 1

    def save(x, decoder_ix, filepath):
        wav = utils.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if args.output_next_to_orig:
            out_path = filepath.parent / f'{filepath.stem}_{decoder_ix}.wav'
        else:
            out_dir = (args.output / str(extract_id(args.model)) /
                       str(args.update))
            if multiple_decoders:
                out_path = out_dir / f'{filepath.stem}_{decoder_ix}.wav'
            else:
                out_path = out_dir / filepath.with_suffix('.wav').name
        save_audio(wav.squeeze(), out_path, rate=args.rate)

    yy = {}
    with torch.no_grad():
        zz = torch.cat([encoder(xs_batch)
                        for xs_batch in torch.split(xs, args.batch_size)],
                       dim=0)

        with utils.timeit("Generation timer"):
            for decoder_id, decoder in zip(decoder_ids, decoders):
                outputs = []
                for zz_batch in torch.split(zz, args.batch_size):
                    print(zz_batch.shape)
                    # Generate the batch chunk by chunk along the time axis;
                    # reset() is called once per batch, before the first chunk.
                    splits = torch.split(zz_batch, args.split_size, -1)
                    decoder.reset()
                    audio_data = [decoder.generate(cond).cpu()
                                  for cond in tqdm.tqdm(splits)]
                    outputs.append(torch.cat(audio_data, -1))
                yy[decoder_id] = torch.cat(outputs, dim=0)

    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)
def _assert_generation_methods_agree(net, x, h, length):
    """Assert that all generation methods of ``net`` agree under argmax decoding.

    Runs ``generate`` and ``fast_generate`` on each example separately and
    ``batch_fast_generate`` on the whole batch. It then checks that the three
    produce identical sequences.

    Args:
        net: a ``WaveNet`` already put in eval mode.
        x: seed samples, one row per example (the caller uses shape
            ``(batch, 1)``).
        h: auxiliary features, one entry per example.
        length: number of samples to generate per example.
    """
    # sample-by-sample generation
    gen1_list = []
    gen2_list = []
    for x_, h_ in zip(x, h):
        batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
        batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
        gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
        gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
        np.testing.assert_array_equal(gen1, gen2)
        gen1_list.append(gen1)
        gen2_list.append(gen2)
    gen1 = np.stack(gen1_list)
    gen2 = np.stack(gen2_list)
    np.testing.assert_array_equal(gen1, gen2)

    # batch generation
    batch_x = torch.from_numpy(x).long()
    batch_h = torch.from_numpy(h).float()
    gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * len(x),
                                        1, "argmax")
    gen3 = np.stack(gen3_list)
    np.testing.assert_array_equal(gen3, gen2)


def test_assert_fast_generation():
    """Check that generate, fast_generate and batch_fast_generate agree.

    Uses argmax decoding. Covers kernel sizes 2 and 3, each without and with
    an ``upsampling_factor``.
    """
    # get batch
    batch = 2
    x = np.random.randint(0, 256, size=(batch, 1))
    h = np.random.randn(batch, 28, 32)
    length = h.shape[-1] - 1

    with torch.no_grad():
        # ------------------------------------------------------
        # models without upsampling, kernel size = 2 and then 3
        # ------------------------------------------------------
        for kernel_size in (2, 3):
            net = WaveNet(256, 28, 4, 4, 10, 3, kernel_size)
            net.apply(initialize)
            net.eval()
            _assert_generation_methods_agree(net, x, h, length)

        # get a new batch for the upsampling models
        upsampling_factor = 10
        x = np.random.randint(0, 256, size=(batch, 1))
        h = np.random.randn(batch, 28, 3)
        length = h.shape[-1] * upsampling_factor - 1

        # ---------------------------------------------------
        # models with upsampling, kernel size = 2 and then 3
        # ---------------------------------------------------
        for kernel_size in (2, 3):
            net = WaveNet(256, 28, 4, 4, 10, 3, kernel_size,
                          upsampling_factor)
            net.apply(initialize)
            net.eval()
            _assert_generation_methods_agree(net, x, h, length)
Example #4
0
        x = batch[:-1]
        logits = net(x)
        sz = logits.size(0)
        loss = loss + nn.functional.cross_entropy(logits, batch[-sz:])
    loss = loss / batch_size
    loss.backward()
    optimizer.step()
    loss_save.append(loss.data[0])
    # monitor progress
    if epoch % 100 == 0:
        batch = next(g)
        print('epoch %i, loss %.4f' % (epoch, loss.data[0]))
        logits = net(batch[:-1])
        _, i = logits.max(dim=1)
        plt.figure(figsize=[16, 4])
        plt.plot(i.data.tolist())
        plt.plot(batch.data.tolist()[sum(net.dilations) + 1:])
        plt.legend(['generated', 'data'])
        plt.title('epoch %i' % epoch)
        plt.tight_layout()
        plt.savefig('train/epoch_%i.png' % epoch)

# Generate from the most recent training `batch` (4000 is presumably the
# number of samples to produce), plot it against the data, then save weights.
y_gen = net.generate(batch, 4000)
fig, ax = plt.subplots(figsize=[16, 4])
ax.plot(y_gen, '--', c='b')
# NOTE(review): `ms` only affects markers; with no marker given it is a no-op.
ax.plot(batch.data.tolist(), ms=2, c='k')
ax.legend(['generated', 'data'])
fig.savefig('train/generation.png')

torch.save(net.state_dict(), 'train/wavenet.pt')
Example #5
0
def main():
    """Generate a waveform sample by sample with a trained WaveNet and save it.

    Options come from ``get_arguments()``. Weights are restored from
    ``args.checkpoint``. When ``args.use_aux_features`` is set, mel cepstra
    computed from ``args.aux_source`` are fed alongside the sample window,
    and the number of generated samples is capped to what they can cover.
    The quantized output (including the seed) is mu-law decoded and written
    to ``args.wav_out_path``.

    Raises:
        ValueError: the auxiliary source has no more frames than the model's
            receptive field, so no sample could be generated.
    """
    args = get_arguments()
    # load audio file to use as source of aux features
    if args.use_aux_features:
        feature_source, _ = librosa.load(args.aux_source,
                                         sr=SAMPLE_RATE,
                                         mono=False)
        # if stereo file, take left channel
        if len(feature_source.shape) > 1:
            feature_source = feature_source[0]
        melcep = calculate_mel_cepstrum(feature_source,
                                        sample_rate=SAMPLE_RATE,
                                        hop_length=args.hop_length,
                                        n_mfcc=args.n_mfcc,
                                        n_fft=args.n_fft,
                                        mean_sub=args.mean_sub,
                                        normalise=args.normalise,
                                        mean_file=args.mean_file,
                                        std_file=args.std_file)

    # The context manager closes the session when generation is finished.
    with tf.Session() as sess:
        wavenet = WaveNet(dilations=DILATIONS,
                          residual_channels=RESIDUAL_CHANNELS,
                          dilation_channels=DILATION_CHANNELS,
                          quantization_channels=QUANTIZATION_CHANNELS,
                          skip_channels=SKIP_CHANNELS,
                          use_aux_features=args.use_aux_features,
                          n_mfcc=args.n_mfcc)
        receptive_field = wavenet.receptive_field

        # placeholders for samples and aux input
        if args.use_aux_features:
            aux_input = tf.placeholder(dtype=tf.float32,
                                       shape=(1, None, args.n_mfcc))
            # Step `i` reads aux frames [i, i + receptive_field), so the
            # sample count is capped to what the aux source can cover.
            max_samples = melcep.shape[0] - receptive_field
            if max_samples < 1:
                raise ValueError(
                    f'Aux source {args.aux_source} yields {melcep.shape[0]} '
                    f'frames, which does not exceed the receptive field '
                    f'({receptive_field}); nothing can be generated.')
            args.samples = min(args.samples, max_samples)
        else:
            aux_input = None
        samples = tf.placeholder(tf.int32)

        output = wavenet.generate(samples,
                                  aux_input,
                                  use_aux_features=args.use_aux_features)

        saver = tf.train.Saver()
        saver.restore(sess, args.checkpoint)

        # Seed waveform: receptive_field - 1 mid-range values plus one random
        # value. Integer division keeps the seed integral, matching what the
        # int32 `samples` placeholder actually receives.
        waveform = [QUANTIZATION_CHANNELS // 2] * (receptive_field - 1)
        waveform.append(np.random.randint(QUANTIZATION_CHANNELS))

        last_step_print = datetime.now()
        # generate waveform
        for step in range(args.samples):
            # The last `receptive_field` samples (or all of them, if fewer).
            window = waveform[-receptive_field:]

            feed = {samples: window}
            if args.use_aux_features:
                window_aux = melcep[step:step + receptive_field]
                feed[aux_input] = np.array([window_aux])
            prediction = sess.run([output], feed_dict=feed)[0]

            sample = np.random.choice(np.arange(QUANTIZATION_CHANNELS),
                                      p=prediction)
            waveform.append(sample)

            # print progress each second
            now = datetime.now()
            time_since_print = now - last_step_print
            if time_since_print.total_seconds() > 1.:
                print("Completed samples:", step, "of", args.samples)
                last_step_print = now

        # Decode samples (inverse mu-law) and save as wave file
        decode_in = tf.placeholder(dtype=tf.float32)
        mu = QUANTIZATION_CHANNELS - 1
        signal = 2 * (decode_in / mu) - 1
        magnitude = (1 / mu) * ((1 + mu)**abs(signal) - 1)
        decoded = tf.sign(signal) * magnitude

        out = sess.run(decoded, feed_dict={decode_in: waveform})

    # NOTE(review): librosa.output.write_wav was removed in librosa 0.8;
    # newer environments need e.g. soundfile.write instead.
    librosa.output.write_wav(args.wav_out_path, out, SAMPLE_RATE)
    print("Sound file generated at", args.wav_out_path)