Esempio n. 1
0
def test_wavenet_batchnorm_output_shape(embed_inputs):
    """A batch-norm Wavenet returns logits shaped (N, n_classes, C, W)."""
    hparams = model.HParams(embed_inputs=embed_inputs, batch_norm=True)
    net = model.Wavenet(hparams)
    batch, width = 3, 4
    targets = torch.randint(hparams.n_classes,
                            (batch, hparams.n_audio_chans, width))
    logits, _ = net.forward(targets.float(), targets)
    expected = (batch, hparams.n_classes, hparams.n_audio_chans, width)
    assert logits.shape == expected
Esempio n. 2
0
def test_generator_forward_one_sample():
    """One-timestep forward through the Generator yields (N, K, C, 1) logits."""
    n_classes = 2**8
    wavenet = model.Wavenet(model.HParams(n_classes=n_classes))
    generator = sample.Generator(wavenet)
    # NOTE(review): the middle dim (2) presumably matches the default
    # n_audio_chans — confirm against model.HParams.
    targets = torch.randint(0, n_classes, (2, 2, 1))
    logits, _ = generator.forward(targets.float(), targets)
    cfg = wavenet.cfg
    assert logits.shape == (2, cfg.n_classes, cfg.n_audio_chans, 1)
Esempio n. 3
0
def test_distributed_data_parallel():
    """Smoke test: one epoch of DDP training on a short fixture track."""
    hparams = model.HParams(n_audio_chans=2, n_layers=8).with_all_chans(2)
    train_cfg = train.HParams(max_epochs=1, batch_size=8)
    # 0.8 split; only the first returned dataset is handed to the trainer,
    # and None is passed in the slot the test split would presumably occupy.
    train_ds, _ = datasets.tracks("fixtures/goldberg/short.wav", 0.8, hparams)
    trainer = distributed.DDP(model.Wavenet(hparams), train_ds, None,
                              train_cfg)
    trainer.train()
Esempio n. 4
0
def test_one_logit_generator_vs_wavenet(batch_norm):
    """Generator and Wavenet agree on the class posterior for one timestep.

    A batch of one-timestep inputs goes through both models. The Wavenet gets
    the raw input. The Generator gets it shifted right by one step (causal),
    so both see the same context.
    """
    p = model.HParams(
        mixed_precision=False,
        n_audio_chans=1,
        n_classes=16,
        n_chans=32,
        dilation_stacks=1,
        n_layers=6,
        batch_norm=batch_norm,
        compress=False,
    )

    m = model.Wavenet(p).eval()
    g = sample.Generator(m).eval()
    # batch of 10 rather than 1 so batch norm never sees a single example
    x = torch.rand((10, m.cfg.n_audio_chans, 1))

    # Logits are (N, K, C, W). Normalize over the class axis K (dim=1). With
    # N=10, squeezing first and using dim=0 would normalize over the batch.
    ym, _ = m.forward(x)
    ym = F.softmax(ym, dim=1).squeeze()

    # same batch through the generator, shifted right one step (causal)
    yg, _ = g.forward(F.pad(x, (1, -1)))
    yg = F.softmax(yg, dim=1).squeeze()

    # tolerance-based compare, as in the many-logits generator test: the two
    # models take different computational paths
    assert torch.allclose(ym, yg)
Esempio n. 5
0
def test_many_logits_generator_vs_wavenet(batch_norm, n_samples):
    """This test doesn't do any sampling. Instead we compare logits. On the
    wavenet, this can be done with a single forward pass. The generator only
    accepts a one sample input, so we have to generate that by passing through
    one piece of the input at a time while appending all the outputs.
    """

    p = model.HParams(
        mixed_precision=False,
        n_audio_chans=1,
        n_classes=16,
        dilation_stacks=1,
        n_layers=1,
        batch_norm=batch_norm,
        compress=False,
    ).with_all_chans(32)

    utils.seed(p)  # reset seeds and use deterministic mode

    # set up model and generator
    m = model.Wavenet(p).eval()

    def jiggle(module):
        # move batch-norm params and running stats off their defaults
        if isinstance(module, nn.BatchNorm1d):
            module.weight.data.fill_(1.1)
            module.bias.data.fill_(0.1)
            module.running_var.fill_(1.1)
            module.running_mean.fill_(0.1)

    if p.batch_norm:
        # Make sure that we have something different to vanilla init. Negative
        # test will otherwise not fail
        m.apply(jiggle)

    # a single forward pass through the wavenet.
    x = torch.rand((1, m.cfg.n_audio_chans, n_samples))
    ym, _ = m.forward(x)

    # Feed the generator one timestep at a time to rebuild the logits the
    # wavenet produced in a single pass. The generator is stateful across
    # calls, so steps must be fed in order. Collect per-step logits and
    # concatenate once, instead of re-concatenating the growing tensor on
    # every step (quadratic copying).
    g = sample.Generator(m).eval()
    x = F.pad(x, (1, -1))  # causal: shift input right by one timestep
    step_logits = []
    for i in range(n_samples):
        logits, _ = g.forward(x[:, :, i:i + 1])
        step_logits.append(logits)
    yg = torch.cat(step_logits, -1)

    # posterior over the class axis K of (N, K, C, W) logits
    ym = F.softmax(ym, dim=1).squeeze()
    yg = F.softmax(yg, dim=1).squeeze()

    assert torch.allclose(ym, yg)
Esempio n. 6
0
def test_many_logits_fast_vs_simple(embed_inputs, n_audio_chans, batch_norm):
    """sample.fast and sample.simple yield the same class posteriors.

    Both samplers run the same model from the same seed. The decoder re-seeds
    before every draw, so the logits the two samplers return should match.
    """

    n_samples, n_examples = 100, 5
    p = model.HParams(
        mixed_precision=False,
        embed_inputs=embed_inputs,
        n_audio_chans=n_audio_chans,
        n_classes=20,
        dilation_stacks=1,
        n_layers=2,
        batch_norm=batch_norm,
        compress=False,
        sample_length=n_samples,
        seed=135,
    ).with_all_chans(16)

    utils.seed(p)
    ds = datasets.Tiny(n_samples, n_examples)
    m = model.Wavenet(p).eval()

    def decoder(logits):
        # re-seed so every draw sees identical randomness in both samplers
        utils.seed(p)
        return utils.decode_random(logits)

    def jiggle(module):
        # move batch-norm params and running stats off their defaults
        if isinstance(module, nn.BatchNorm1d):
            module.weight.data.fill_(1.1)
            module.bias.data.fill_(0.1)
            module.running_var.fill_(1.1)
            module.running_mean.fill_(0.1)

    if p.batch_norm:
        # Make sure that we have something different to vanilla init. Negative
        # test will otherwise not fail
        m.apply(jiggle)

    # simple
    utils.seed(p)
    _, simple_logits = sample.simple(m,
                                     ds.transforms,
                                     decoder,
                                     n_samples=n_samples,
                                     batch_size=n_examples)

    # fast (the returned generator is not needed here)
    utils.seed(p)
    _, fast_logits, _ = sample.fast(m,
                                    ds.transforms,
                                    decoder,
                                    n_samples=n_samples,
                                    batch_size=n_examples)

    # Normalize over the class axis before squeezing. Assumes both samplers
    # return logits laid out (N, K, C, W) like Wavenet.forward (TODO confirm).
    # Squeezing first and using dim=0 would normalize over the n_examples
    # batch axis instead of the classes.
    simple_logits = torch.softmax(simple_logits, dim=1).squeeze()
    fast_logits = torch.softmax(fast_logits, dim=1).squeeze()

    assert torch.allclose(fast_logits, simple_logits)
Esempio n. 7
0
def test_loss_stable_across_batch_sizes():
    """Mean initial loss should barely move between batch sizes 1 and 100.

    For each size, 100 freshly initialised models are scored. The spread of
    the per-size mean losses must stay under 0.25.
    """

    def loss_stats(size):
        # returns (mean, std) of the loss over 100 fresh models
        samples = []
        for _ in range(100):
            hparams = model.HParams()
            x, y = datasets.to_tensor(datasets.StereoImpulse(size, 8, hparams))
            _, loss = model.Wavenet(hparams).forward(x, y)
            samples.append(loss.detach().numpy())
        return np.mean(samples), np.std(samples)

    batch_sizes = {size: loss_stats(size) for size in (1, 100)}
    means = [mean for mean, _ in batch_sizes.values()]
    assert np.std(means) < 0.25, batch_sizes
Esempio n. 8
0
def test_checkpoint():
    """Trainer state written by utils.checkpoint loads back via torch.load."""
    expected_keys = ("model", "optimizer", "scaler", "schedule", "epoch",
                     "best")
    with helpers.tempdir() as tmp:
        wavenet = model.Wavenet(model.HParams())
        train_cfg = train.HParams()
        trainer = train.Trainer(wavenet, [1, 2, 3], [4, 5], train_cfg)
        path = tmp / "checkpoint"
        utils.checkpoint("test", trainer.state(), train_cfg, path)
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True; confirm the saved state is compatible.
        state = torch.load(path)
        for key in expected_keys:
            assert key in state, key
        trainer.load_state(state)
Esempio n. 9
0
def test_shifted_units_generator_vs_wavenet_one_sample():
    """The Generator's shifted unit reproduces the Wavenet's on one sample.

    The Generator sees the input shifted right by one step (causal).
    """
    hparams = model.HParams(
        mixed_precision=False,
        n_audio_chans=1,
        n_classes=16,
        n_chans=32,
        dilation_stacks=1,
        n_layers=6,
        compress=False,
    )

    wavenet = model.Wavenet(hparams)
    generator = sample.Generator(wavenet)
    x = torch.rand((1, hparams.n_audio_chans, 1))
    expected = wavenet.shifted.forward(x).squeeze()
    causal_x = F.pad(x, (1, -1))
    actual = generator.shifted.forward(causal_x).squeeze()
    assert torch.all(expected == actual)
Esempio n. 10
0
def test_loss_jacobian_full_receptive_field(embed_inputs):
    """The last output's loss depends on exactly the receptive field.

    The check uses the input/output jacobian of the per-timestep loss. Its
    last row (the last output timestep) must be non-zero for exactly the
    receptive_field_size() inputs that immediately precede it.
    """
    batch_size = 2
    p = model.HParams(
        embed_inputs=embed_inputs,
        n_audio_chans=1,
        n_classes=2,
        dilation_stacks=2,
        n_layers=4,
        sample_length=40,
    ).with_all_chans(10)

    m = model.Wavenet(p)

    # pin down the expected receptive field, and make sure the sample is
    # longer than it so the zero region is visible
    rf = p.receptive_field_size()
    assert rf == 32, rf
    assert p.sample_length > rf

    # every target is class index 1 (the second of the two classes)
    targets = torch.ones((batch_size, 1, p.sample_length), dtype=torch.long)

    def loss(x):
        logits, _ = m.forward(x)
        per_element = F.cross_entropy(logits, targets, reduction="none")
        return per_element.sum(1)  # N, C, W -> N, W

    # input is N, C, W and output is N, W, so the jacobian is N, W, N, C, W;
    # summing every other axis leaves a W x W matrix
    x = torch.rand((batch_size, 1, p.sample_length))
    jac = jacobian(loss, x).sum((0, 2, 3))

    # Last row: the last output timestep's sensitivity to each input. The
    # last input is dropped because temporal masking stops it from
    # influencing the last output.
    influence = jac[-1, :-1].ne(0.0)

    # the number of influential inputs equals the receptive field...
    assert influence.sum() == rf

    # ...and they are exactly the rf inputs nearest the output
    expected = torch.zeros_like(influence)
    expected[-rf:] = True
    assert expected.equal(influence)
Esempio n. 11
0
def test_wavenet_modules_registered():
    """A one-layer Wavenet registers exactly these parameters, in this order."""
    wavenet = model.Wavenet(model.HParams(n_layers=1, dilation_stacks=1))
    modules = (
        "shifted",
        "layers.0.conv",
        "layers.0.res1x1",
        "layers.0.skip1x1",
        "a1x1",
        "b1x1",
    )
    expected = [f"{name}.{param}" for name in modules
                for param in ("weight", "bias")]

    assert list(wavenet.state_dict()) == expected
Esempio n. 12
0
def test_logit_jacobian_first_sample():
    """With a single timestep, the logits must not depend on the input.

    The model is causal: the only output timestep cannot see the only input
    timestep, so the whole jacobian must be zero.
    """
    p = model.HParams()
    X = datasets.StereoImpulse(1, 1, p)
    m = model.Wavenet(p)

    def logits(x):
        "we are only interested in the time dimensions W. keeping n for loss"
        logits, _ = m.forward(x)
        return logits.sum((1, 2))  # N, K, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    x, _ = X[0]
    j = jacobian(logits, x.unsqueeze(0))

    # sum everything else to obtain WxW
    j = j.sum((0, 2, 3))

    # Every entry must be exactly zero. count_nonzero keeps the assertion a
    # single boolean. The unique(...) == zeros(1) comparison raised an
    # ambiguous-truth error whenever j held more than one distinct value.
    assert torch.count_nonzero(j) == 0, j
Esempio n. 13
0
def test_logit_jacobian_many_samples():
    """Logits at timestep t must depend only on inputs strictly before t.

    The jacobian of the logits with respect to the input must therefore be
    strictly lower triangular (zero on and above the diagonal).
    """
    p = model.HParams()
    X = datasets.StereoImpulse(1, 8, p)  # 8 samples
    m = model.Wavenet(p)

    def logits(x):
        "we are only interested in the time dimensions W. keeping n for loss"
        logits, _ = m.forward(x)
        return logits.sum((1, 2))  # N, K, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    x, _ = X[0]
    j = jacobian(logits, x.unsqueeze(0))

    # sum everything else to obtain WxW (rows: outputs, cols: inputs)
    j = j.sum((0, 2, 3))

    # Strictly lower triangular. The causal shift means output t cannot see
    # input t, so the diagonal must be zero too. A plain tril(j) would accept
    # a non-zero diagonal.
    assert torch.equal(torch.tril(j, diagonal=-1), j)
Esempio n. 14
0
def test_loss_jacobian_many_samples():
    """Per-timestep loss at t must depend only on inputs strictly before t.

    Targets are derived from the input itself. The jacobian of the loss with
    respect to the input must still be strictly lower triangular.
    """
    p = model.HParams()
    X = datasets.StereoImpulse(1, 8, p)  # 8 samples
    m = model.Wavenet(p)

    def loss(x):
        logits, _ = m.forward(x)
        targets = utils.audio_to_class_idxs(x, p.n_classes)
        losses = F.cross_entropy(logits, targets, reduction="none")
        return losses.sum(1)  # N, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    x, _ = X[0]
    j = jacobian(loss, x.unsqueeze(0))

    # sum everything else to obtain WxW (rows: outputs, cols: inputs)
    j = j.sum((0, 2, 3))

    # Strictly lower triangular. The causal shift means output t cannot see
    # input t, so the diagonal must be zero too. A plain tril(j) would accept
    # a non-zero diagonal.
    assert torch.equal(torch.tril(j, diagonal=-1), j)
Esempio n. 15
0
def test_shifted_weights_generator_vs_wavenet():
    """The Generator's shifted conv holds the same weights as the Wavenet's."""
    wavenet = model.Wavenet(model.HParams())
    shifted_conv = sample.Generator(wavenet).shifted.c
    assert torch.equal(wavenet.shifted.weight, shifted_conv.weight)
Esempio n. 16
0
def test_wavenet_dilation_stacks():
    """Each dilation stack repeats the per-layer dilation pattern 1, 2."""
    net = model.Wavenet(model.HParams(n_layers=2, dilation_stacks=2))
    observed = [layer.conv.dilation[0] for layer in net.layers]
    assert observed == [1, 2] * 2
Esempio n. 17
0
def test_wavenet_mono_input_embedded_output_shape():
    """A mono Wavenet with embedded inputs returns (N, K, 1, W) logits.

    The class count is read from the model config rather than hard-coded, so
    the test keeps working if the default n_classes changes.
    """
    m = model.Wavenet(model.HParams(embed_inputs=True, n_audio_chans=1))
    n_classes = m.cfg.n_classes
    y = torch.randint(n_classes, (3, 1, 4))
    y_hat, _ = m.forward(y.float(), y)
    assert y_hat.shape == (3, n_classes, 1, 4)
Esempio n. 18
0
def test_onecycle():
    """utils.onecycle returns a schedule whose total_steps equals the request."""
    total_steps = 9
    train_cfg = train.HParams(batch_size=1, max_epochs=1)
    net = model.Wavenet(model.HParams())
    sgd = torch.optim.SGD(net.parameters(), lr=train_cfg.learning_rate)
    schedule = utils.onecycle(sgd, total_steps, train_cfg)
    assert schedule.total_steps == total_steps
Esempio n. 19
0
def test_generator_init():
    """Smoke test: a Generator can be built from a default Wavenet."""
    generator = sample.Generator(model.Wavenet(model.HParams()))
    assert generator
Esempio n. 20
0
def test_two_samples():
    """Fast sampling of two timesteps yields a (1, C, 2) track."""
    n_samples = 2
    net = model.Wavenet(model.HParams())
    transforms = datasets.AudioUnitTransforms(net.cfg)
    track, *_ = sample.fast(net, transforms, utils.decode_random,
                            n_samples=n_samples)
    assert track.shape == (1, net.cfg.n_audio_chans, n_samples)
Esempio n. 21
0
def test_many_samples_with_embedding():
    """Fast sampling of 50 timesteps with embedded inputs yields (1, C, 50)."""
    n_samples = 50
    net = model.Wavenet(model.HParams(embed_inputs=True))
    transforms = datasets.AudioUnitTransforms(net.cfg)
    track, *_ = sample.fast(net, transforms, utils.decode_random,
                            n_samples=n_samples)
    assert track.shape == (1, net.cfg.n_audio_chans, n_samples)
Esempio n. 22
0
def test_wavenet_output_shape():
    """A default Wavenet returns logits shaped (N, n_classes, C, W).

    Class and channel counts come from the model config instead of the
    hard-coded 256 / 2, so the test tracks the defaults.
    """
    m = model.Wavenet(model.HParams())
    n_classes, n_chans = m.cfg.n_classes, m.cfg.n_audio_chans
    y = torch.randint(n_classes, (3, n_chans, 4))
    y_hat, _ = m.forward(y.float(), y)
    assert y_hat.shape == (3, n_classes, n_chans, 4)
Esempio n. 23
0
def test_wavenet_mono_output_shape():
    """A mono Wavenet returns logits shaped (N, n_classes, 1, W).

    The output gets its own name instead of overwriting the input `x`. The
    class count comes from the model config instead of a hard-coded 256.
    """
    m = model.Wavenet(model.HParams(n_audio_chans=1))
    n_classes = m.cfg.n_classes
    y = torch.randint(n_classes, (3, 1, 4))
    x = y.float()
    y_hat, _ = m.forward(x, y)
    assert y_hat.shape == (3, n_classes, 1, 4)
Esempio n. 24
0
def test_lrfinder():
    """utils.lrfinder builds a schedule whose multiplicative gamma is 10."""
    net = model.Wavenet(model.HParams())
    sgd = torch.optim.SGD(net.parameters(), lr=1e-8)
    train_cfg = train.HParams(batch_size=1, max_epochs=1)
    gamma = torch.tensor(utils.lrfinder(sgd, 9, train_cfg).gamma)
    assert torch.isclose(gamma, torch.tensor(10.0))