コード例 #1
0
def test_tracks():
    """Check track metadata, dataset length and example shapes for Tracks."""
    with helpers.tempdir() as cache:
        hparams = model.HParams(sample_length=16000, sample_overlap_length=0)
        root = Path("fixtures/goldberg")
        ds = datasets.Tracks.from_dir(hparams, root, cache_dir=cache)

        # the expected metadata points at the cache sub-directory keyed by
        # the audio hyper-parameters
        keyed_cache = cache / hparams.audio_cache_key()
        expected_durations = [
            ("goldberg.wav", 1507200),
            ("short.wav", 150431),
            ("aria.wav", 4792320),
        ]
        expected_tracks = {
            datasets.TrackMeta(root, keyed_cache, Path(name), n)
            for name, n in expected_durations
        }
        assert set(ds.tracks) == expected_tracks

        # spot check the length of short.wav after resampling; the value is
        # one sample off from the 150431 recorded in the metadata above.
        short_path, claimed = root / "short.wav", 150432
        resampled = audio.load_resampled(short_path, hparams)
        _, n_samples = resampled.shape
        assert claimed == n_samples

        # the dataset length must be consistent with the pruned durations
        want_duration = sum(
            audio.prune_duration(t.duration, hparams) for t in ds.tracks)
        assert len(ds) * hparams.sample_length == want_duration

        # retrieve only the first and last examples; more would be expensive
        # due to the resampling step.
        for idx in (0, len(ds) - 1):
            x, y, meta = ds[idx]
            for tensor in (x, y):
                assert tensor is not None
                assert tensor.shape == (2, hparams.sample_length)
            assert meta is not None
コード例 #2
0
def test_sines_dataset():
    """Check example shapes, length and repr of a four-example Sines dataset."""
    d = datasets.Sines(4, model.HParams())
    x, y = d[0]
    # both halves of the example are checked; previously y was asserted
    # twice and x not at all.
    assert x.shape == (2, 16000)  # stereo
    assert y.shape == (2, 16000)  # stereo
    assert len(d) == 4
    assert repr(d) == "Sines(nseconds: 1.0)"
コード例 #3
0
def test_generator_forward_one_sample():
    """A one-timestep forward pass through the generator has the expected shape."""
    n_classes = 2**8
    wavenet = model.Wavenet(model.HParams(n_classes=n_classes))
    generator = sample.Generator(wavenet)

    # integer targets of width one; the float copy is fed in as the input
    targets = torch.randint(0, n_classes, (2, 2, 1))
    out, _ = generator.forward(targets.float(), targets)

    cfg = wavenet.cfg
    assert out.shape == (2, cfg.n_classes, cfg.n_audio_chans, 1)
コード例 #4
0
def test_wavenet_batchnorm_output_shape(embed_inputs):
    """With batch norm enabled the output is (batch, classes, chans, width)."""
    hparams = model.HParams(embed_inputs=embed_inputs, batch_norm=True)
    wavenet = model.Wavenet(hparams)

    batch, width = 3, 4
    targets = torch.randint(hparams.n_classes,
                            (batch, hparams.n_audio_chans, width))
    y_hat, _ = wavenet.forward(targets.float(), targets)

    want = (batch, hparams.n_classes, hparams.n_audio_chans, width)
    assert y_hat.shape == want
コード例 #5
0
def test_distributed_data_parallel():
    """Smoke test: one epoch of DDP training must run without raising."""
    hparams = model.HParams(n_audio_chans=2, n_layers=8).with_all_chans(2)
    train_hparams = train.HParams(max_epochs=1, batch_size=8)

    # the second half of the split is not needed here
    train_ds, _ = datasets.tracks("fixtures/goldberg/short.wav", 0.8, hparams)

    wavenet = model.Wavenet(hparams)
    trainer = distributed.DDP(wavenet, train_ds, None, train_hparams)
    trainer.train()
コード例 #6
0
def test_one_logit_generator_vs_wavenet(batch_norm):
    """Wavenet and generator must agree exactly on a one-timestep input."""
    hparams = model.HParams(
        mixed_precision=False,
        n_audio_chans=1,
        n_classes=16,
        n_chans=32,
        dilation_stacks=1,
        n_layers=6,
        batch_norm=batch_norm,
        compress=False,
    )

    wavenet = model.Wavenet(hparams).eval()
    generator = sample.Generator(wavenet).eval()

    # a batch of ten one-timestep examples; don't bn with batch size 1
    x = torch.rand((10, wavenet.cfg.n_audio_chans, 1))

    # forward pass through the original network
    wavenet_out, _ = wavenet.forward(x)
    wavenet_probs = F.softmax(wavenet_out.squeeze(), dim=0)

    # forward pass through the generator copy. the input is shifted right by
    # one timestep (pad left, trim right) to keep it causal.
    shifted = F.pad(x, (1, -1))
    generator_out, _ = generator.forward(shifted)
    generator_probs = F.softmax(generator_out.squeeze(), dim=0)

    assert torch.all(wavenet_probs == generator_probs)
コード例 #7
0
def test_track_uncompressed():
    """Check shape, length and repr of a Track built with compress=False."""
    sample_rate = 16000
    hparams = model.HParams(compress=False,
                            sample_overlap_length=sample_rate - 2**13)
    track = datasets.Track("fixtures/goldberg/short.wav", hparams)

    _, y = track[0]
    assert y.shape == (2, sample_rate)
    assert len(track) == 17
    assert repr(track) == "Track(fixtures/goldberg/short.wav)"
コード例 #8
0
def test_many_logits_generator_vs_wavenet(batch_norm, n_samples):
    """Compare logits, without any sampling.

    The wavenet produces all logits in a single forward pass. The generator
    only accepts a one sample input, so the input is fed through one timestep
    at a time and the per-step outputs are concatenated.
    """

    hparams = model.HParams(
        mixed_precision=False,
        n_audio_chans=1,
        n_classes=16,
        dilation_stacks=1,
        n_layers=1,
        batch_norm=batch_norm,
        compress=False,
    ).with_all_chans(32)

    utils.seed(hparams)  # reset seeds and use deterministic mode

    wavenet = model.Wavenet(hparams).eval()

    def jiggle(module):
        """Move batch norm parameters and statistics away from their init."""
        if not isinstance(module, nn.BatchNorm1d):
            return
        module.weight.data.fill_(1.1)
        module.bias.data.fill_(0.1)
        module.running_var.fill_(1.1)
        module.running_mean.fill_(0.1)

    if hparams.batch_norm:
        # Without something different to vanilla init, the negative test
        # would not fail.
        wavenet.apply(jiggle)

    # a single forward pass through the wavenet.
    x = torch.rand((1, wavenet.cfg.n_audio_chans, n_samples))
    ym, _ = wavenet.forward(x)

    # step the generator through the same input, shifted right by one, one
    # timestep at a time. the collected outputs are what the wavenet above
    # produced in one pass.
    generator = sample.Generator(wavenet).eval()
    shifted = F.pad(x, (1, -1))
    step_outputs = []
    for t in range(n_samples):
        step_logits, _ = generator.forward(shifted[:, :, t:t + 1])
        step_outputs.append(step_logits)
    yg = torch.cat(step_outputs, -1)

    # posterior
    wavenet_probs = F.softmax(ym.squeeze(), dim=0)
    generator_probs = F.softmax(yg.squeeze(), dim=0)

    assert torch.allclose(wavenet_probs, generator_probs)
コード例 #9
0
def test_track_dataset_stacked():
    """Stacking a Track into tensors yields 17 stereo examples in range."""
    sample_rate = 16000
    hparams = model.HParams(sample_overlap_length=sample_rate - 2**13)
    track = datasets.Track("fixtures/goldberg/short.wav", hparams)

    x, y = datasets.to_tensor(track)
    stacked_shape = (17, 2, sample_rate)
    assert x.shape == stacked_shape
    assert y.shape == stacked_shape

    # target values must stay within [0, 256]
    assert torch.min(y) >= 0.0
    assert torch.max(y) <= 256.0
コード例 #10
0
def test_sines_fixed_phase_dataset():
    """Check a Sines dataset constructed with an explicit phase of 0.0."""
    n_examples = 4
    sines = datasets.Sines(n_examples, model.HParams(), phase=0.0)

    x, y = sines[0]
    stereo = (2, 16000)
    assert x.shape == stereo
    assert y.shape == stereo

    assert sines.amp.shape == (4, )  # one amp per example
    assert sines.hz.shape == (4, )  # one hz per example
    assert isinstance(sines.phase, float)  # one phase for all examples
    assert len(sines) == n_examples
    assert repr(sines) == "Sines(nseconds: 1.0, phase: 0.0)"
コード例 #11
0
def test_many_logits_fast_vs_simple(embed_inputs, n_audio_chans, batch_norm):
    """sample.fast and sample.simple must produce matching logits."""

    n_samples, n_examples = 100, 5
    hparams = model.HParams(
        mixed_precision=False,
        embed_inputs=embed_inputs,
        n_audio_chans=n_audio_chans,
        n_classes=20,
        dilation_stacks=1,
        n_layers=2,
        batch_norm=batch_norm,
        compress=False,
        sample_length=n_samples,
        seed=135,
    ).with_all_chans(16)

    utils.seed(hparams)
    ds = datasets.Tiny(n_samples, n_examples)
    wavenet = model.Wavenet(hparams).eval()

    def decoder(logits):
        # re-seed before every draw so both samplers decode identically
        utils.seed(hparams)
        return utils.decode_random(logits)

    def jiggle(module):
        """Move batch norm parameters and statistics away from their init."""
        if not isinstance(module, nn.BatchNorm1d):
            return
        module.weight.data.fill_(1.1)
        module.bias.data.fill_(0.1)
        module.running_var.fill_(1.1)
        module.running_mean.fill_(0.1)

    if hparams.batch_norm:
        # Without something different to vanilla init, the negative test
        # would not fail.
        wavenet.apply(jiggle)

    sampler_kwargs = dict(n_samples=n_samples, batch_size=n_examples)

    # simple
    utils.seed(hparams)
    _, simple_logits = sample.simple(wavenet, ds.transforms, decoder,
                                     **sampler_kwargs)
    simple_probs = torch.softmax(simple_logits.squeeze(), dim=0)

    # fast; the third return value is not used here
    utils.seed(hparams)
    _, fast_logits, _ = sample.fast(wavenet, ds.transforms, decoder,
                                    **sampler_kwargs)
    fast_probs = torch.softmax(fast_logits.squeeze(), dim=0)

    assert torch.allclose(fast_probs, simple_probs)
コード例 #12
0
def test_track():
    """Splitting short.wav at 0.4 gives 10 and 6 stereo examples."""
    sample_rate = 16000
    hparams = model.HParams(sample_overlap_length=sample_rate - 2**13)
    first, second = datasets.tracks("fixtures/goldberg/short.wav", 0.4,
                                    hparams)

    x, y = first[0]
    x_test, y_test = second[0]

    assert len(first) == 10
    assert len(second) == 6
    for tensor in (x, y, x_test, y_test):
        assert tensor.shape == (2, sample_rate)
コード例 #13
0
def test_loss_stable_across_batch_sizes():
    """Mean loss of fresh models should be similar for batch sizes 1 and 100."""
    n_trials = 100
    stats = {}
    for batch_size in (1, 100):
        losses = []
        for _ in range(n_trials):
            # a freshly initialised model and dataset for every trial
            hparams = model.HParams()
            impulses = datasets.StereoImpulse(batch_size, 8, hparams)
            x, y = datasets.to_tensor(impulses)
            wavenet = model.Wavenet(hparams)
            _, loss = wavenet.forward(x, y)
            losses.append(loss.detach().numpy())
        stats[batch_size] = (np.mean(losses), np.std(losses))

    means = [mean for mean, _ in stats.values()]
    assert np.std(means) < 0.25, stats
コード例 #14
0
def test_checkpoint():
    """A saved checkpoint holds every expected key and can be loaded back."""
    expected_keys = ("model", "optimizer", "scaler", "schedule", "epoch",
                     "best")
    with helpers.tempdir() as tmp:
        wavenet = model.Wavenet(model.HParams())
        train_hparams = train.HParams()
        trainer = train.Trainer(wavenet, [1, 2, 3], [4, 5], train_hparams)

        filename = tmp / "checkpoint"
        utils.checkpoint("test", trainer.state(), train_hparams, filename)

        state = torch.load(filename)
        for key in expected_keys:
            assert key in state

        # the round trip must not raise
        trainer.load_state(state)
コード例 #15
0
def test_shifted_units_generator_vs_wavenet_one_sample():
    """The `shifted` units of wavenet and generator agree on one timestep."""
    hparams = model.HParams(
        mixed_precision=False,
        n_audio_chans=1,
        n_classes=16,
        n_chans=32,
        dilation_stacks=1,
        n_layers=6,
        compress=False,
    )

    wavenet = model.Wavenet(hparams)
    generator = sample.Generator(wavenet)

    x = torch.rand((1, hparams.n_audio_chans, 1))
    want = wavenet.shifted.forward(x).squeeze()

    # the generator is fed the input shifted right by one timestep (causal)
    causal_x = F.pad(x, (1, -1))
    got = generator.shifted.forward(causal_x).squeeze()

    assert torch.all(want == got)
コード例 #16
0
def test_loss_jacobian_full_receptive_field(embed_inputs):
    """The last output's loss depends on exactly the receptive field inputs."""
    batch_size = 2
    hparams = model.HParams(
        embed_inputs=embed_inputs,
        n_audio_chans=1,
        n_classes=2,
        dilation_stacks=2,
        n_layers=4,
        sample_length=40,
    ).with_all_chans(10)

    wavenet = model.Wavenet(hparams)

    # pin down the expected receptive field, and make sure that the sample
    # is longer than it
    rf_size = hparams.receptive_field_size()
    assert rf_size == 32, rf_size
    assert hparams.sample_length > rf_size

    # every target is class index 1
    targets = torch.ones((batch_size, 1, hparams.sample_length),
                         dtype=torch.long)

    def per_timestep_loss(inputs):
        logits, _ = wavenet.forward(inputs)
        unreduced = F.cross_entropy(logits, targets, reduction="none")
        return unreduced.sum(1)  # N, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    x = torch.rand((batch_size, 1, hparams.sample_length))
    full_jacobian = jacobian(per_timestep_loss, x)

    # sum everything else to obtain WxW
    time_jacobian = full_jacobian.sum((0, 2, 3))

    # pick the last row of the WxW jacobian. these are the derivatives of the
    # last output timestep with respect to each input timestep. the last
    # input timestep is chopped off, since it cannot have an effect on the
    # last output timestep due to temporal masking.
    receptive_field = time_jacobian[-1, :-1]
    nonzero = receptive_field.ne(0.0)

    # the number of contributing inputs matches the receptive field size
    assert nonzero.sum() == rf_size

    # and they are exactly the trailing `rf_size` input timesteps
    expected = torch.zeros_like(receptive_field)
    expected[-rf_size:] = 1
    assert expected.ne(0.0).equal(nonzero)
コード例 #17
0
def test_logit_jacobian_first_sample():
    """With a single timestep, the logits must not depend on the input."""
    p = model.HParams()
    X = datasets.StereoImpulse(1, 1, p)
    m = model.Wavenet(p)

    def logits(x):
        "we are only interested in the time dimensions W. keeping n for loss"
        logits, _ = m.forward(x)
        return logits.sum((1, 2))  # N, K, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    x, _ = X[0]
    j = jacobian(logits, x.unsqueeze(0))

    # sum everything else to obtain WxW
    j = j.sum((0, 2, 3))

    # gradients must remain unaffected by the input, so every entry is zero.
    # compared with torch.equal so that a violation fails the assertion
    # instead of raising on the truth value of a multi-element tensor.
    assert torch.equal(j, torch.zeros_like(j))
コード例 #18
0
def test_logit_jacobian_many_samples():
    """The time-by-time jacobian of the logits must be lower triangular."""
    hparams = model.HParams()
    impulses = datasets.StereoImpulse(1, 8, hparams)  # 8 samples
    wavenet = model.Wavenet(hparams)

    def summed_logits(inputs):
        "only the time dimension W is of interest. n is kept for the loss"
        out, _ = wavenet.forward(inputs)
        return out.sum((1, 2))  # N, K, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    x, _ = impulses[0]
    batched = x.unsqueeze(0)

    # sum everything else to obtain WxW
    time_jacobian = jacobian(summed_logits, batched).sum((0, 2, 3))

    # nothing may be left above the diagonal
    assert torch.equal(torch.tril(time_jacobian), time_jacobian)
コード例 #19
0
def test_wavenet_modules_registered():
    """The state dict of a one-layer wavenet lists its parameters in order."""
    hparams = model.HParams(n_layers=1, dilation_stacks=1)
    wavenet = model.Wavenet(hparams)

    # every registered unit contributes a weight followed by a bias
    units = [
        "shifted",
        "layers.0.conv",
        "layers.0.res1x1",
        "layers.0.skip1x1",
        "a1x1",
        "b1x1",
    ]
    want = [f"{unit}.{kind}" for unit in units for kind in ("weight", "bias")]

    got = list(wavenet.state_dict())
    assert got == want
コード例 #20
0
def test_loss_jacobian_many_samples():
    """The time-by-time jacobian of the loss must be lower triangular."""
    hparams = model.HParams()
    impulses = datasets.StereoImpulse(1, 8, hparams)  # 8 samples
    wavenet = model.Wavenet(hparams)

    def per_timestep_loss(inputs):
        out, _ = wavenet.forward(inputs)
        class_idxs = utils.audio_to_class_idxs(inputs, hparams.n_classes)
        unreduced = F.cross_entropy(out, class_idxs, reduction="none")
        return unreduced.sum(1)  # N, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    x, _ = impulses[0]
    batched = x.unsqueeze(0)

    # sum everything else to obtain WxW
    time_jacobian = jacobian(per_timestep_loss, batched).sum((0, 2, 3))

    # nothing may be left above the diagonal
    assert torch.equal(torch.tril(time_jacobian), time_jacobian)
コード例 #21
0
def test_memoed_shifted_causal1d():
    """Stepping a memoised shifted causal conv matches the full convolution.

    The behavior is a bit different here. Before, the first input value was
    x_0. Here, the first input value is always zero, followed by x_0, x_1,
    etc. On top of this, we have a left padding of one, as before.

    In the source network, this is implemented with a right rotation of x,
    followed by left padding.
    """

    hparams = model.HParams()
    utils.seed(hparams)  # reset seeds and use deterministic mode

    batch, chans, width = 1, 1, 8
    conv = model.Conv1d(1, 1, 2, shifted=True, causal=True, dilation=1)

    # the expected behavior: the whole input in one call
    x = torch.rand((batch, chans, width))
    want = conv(x)

    # with a kernel size 2, you have to remember 1 past input element. this is
    # combined with the current element in order to compute the output.
    memoed = sample.Memo(conv)

    # feed the right-shifted input one timestep at a time, in order
    shifted = F.pad(x, (1, -1))
    steps = [memoed(shifted[:, :, t:t + 1]) for t in range(width)]

    # want the same padding behavior as ShiftedCausal1d
    got = torch.cat(steps, dim=2)
    assert torch.allclose(want, got)
コード例 #22
0
def test_tracks_overlapped_receptive_fields():
    """With overlapping examples the total dataset duration is as expected."""

    with helpers.tempdir() as cache:
        p = model.HParams(sample_length=16000, sample_overlap_length=8000)
        root = Path("fixtures/goldberg")
        ds = datasets.Tracks.from_dir(p, root, cache_dir=cache)

        def expected_duration(ds):
            """Expected number of audio samples summed over all tracks."""
            total = 0
            for track in ds.tracks:
                hop = p.sample_hop_length()

                # prune so that examples will fit in exactly
                pruned = audio.prune_duration(track.duration, p)

                # length if you would pad all trailing examples that are
                # actually too long to fit into the track
                padded = math.floor(pruned / hop) * p.sample_length

                # length of the overcounted trailing examples. these are too
                # long to be contained in the track
                examples_per_hop = math.floor(p.sample_length / hop)
                overcounted = (examples_per_hop - 1) * p.sample_length

                total += padded - overcounted
            return total

        # number of audio samples across all examples in the dataset
        dataset_duration = 0
        for idx in range(len(ds)):
            x, _, _ = ds[idx]
            width = x.shape[-1]
            dataset_duration += width
            assert width == p.sample_length, idx

        assert expected_duration(ds) == dataset_duration
コード例 #23
0
def test_generator_init():
    """A generator can be constructed from a default wavenet."""
    wavenet = model.Wavenet(model.HParams())
    generator = sample.Generator(wavenet)
    assert generator
コード例 #24
0
def test_many_samples_with_embedding():
    """Fast sampling of 50 steps works when inputs are embedded."""
    n_samples = 50
    wavenet = model.Wavenet(model.HParams(embed_inputs=True))
    transforms = datasets.AudioUnitTransforms(wavenet.cfg)

    # only the first return value is checked
    track, *_ = sample.fast(wavenet,
                            transforms,
                            utils.decode_random,
                            n_samples=n_samples)
    assert track.shape == (1, wavenet.cfg.n_audio_chans, n_samples)
コード例 #25
0
def test_two_samples():
    """Fast sampling of two steps returns a track of width two."""
    n_samples = 2
    wavenet = model.Wavenet(model.HParams())
    transforms = datasets.AudioUnitTransforms(wavenet.cfg)

    # only the first return value is checked
    track, *_ = sample.fast(wavenet,
                            transforms,
                            utils.decode_random,
                            n_samples=n_samples)
    assert track.shape == (1, wavenet.cfg.n_audio_chans, n_samples)
コード例 #26
0
def test_shifted_weights_generator_vs_wavenet():
    """The generator's shifted unit has the same weights as the wavenet's."""
    wavenet = model.Wavenet(model.HParams())
    generator = sample.Generator(wavenet)

    wavenet_weight = wavenet.shifted.weight
    generator_weight = generator.shifted.c.weight
    assert torch.equal(wavenet_weight, generator_weight)
コード例 #27
0
def test_sines_dataset_stacked():
    """Stacked Sines targets have shape (4, 2, 16000) and lie in [0, 256]."""
    sines = datasets.Sines(4, model.HParams())

    # only the targets are checked
    _, y = datasets.to_tensor(sines)
    assert y.shape == (4, 2, 16000)
    assert torch.min(y) >= 0.0
    assert torch.max(y) <= 256.0
コード例 #28
0
def test_sines_dataloader():
    """The first batch from a DataLoader over Sines has batch dimension 4."""
    batch_size = 4
    sines = datasets.Sines(10, model.HParams())
    loader = torch.utils.data.dataloader.DataLoader(sines,
                                                    batch_size=batch_size)

    x, y = next(iter(loader))
    batch_shape = (batch_size, 2, 16000)
    assert x.shape == batch_shape
    assert y.shape == batch_shape
コード例 #29
0
def test_maestro():
    """The maestro fixture for 2018 yields two datasets of length one."""
    root_dir = Path("fixtures/maestro")
    train_ds, test_ds = datasets.maestro(root_dir, model.HParams(), year=2018)
    for split in (train_ds, test_ds):
        assert len(split) == 1
コード例 #30
0
def test_stereo_impulse_dataset():
    """Check target shape, repr and length of a StereoImpulse dataset."""
    n_examples, width = 10, 4
    impulses = datasets.StereoImpulse(n_examples, width, model.HParams())

    _, y = impulses[0]
    assert y.shape == (2, width)
    assert repr(impulses) == "StereoImpulse()"
    assert len(impulses) == n_examples