def test_wavenet_batchnorm_output_shape(embed_inputs):
    """With batch norm on, the wavenet's first output is shaped (N, K, C, W)."""
    n_examples, width = 3, 4
    cfg = model.HParams(embed_inputs=embed_inputs, batch_norm=True)
    net = model.Wavenet(cfg)

    # integer class targets; the float copy of the same values is the input
    targets = torch.randint(cfg.n_classes, (n_examples, cfg.n_audio_chans, width))
    inputs = targets.float()
    logits, _ = net.forward(inputs, targets)

    want = (n_examples, cfg.n_classes, cfg.n_audio_chans, width)
    assert logits.shape == want
def test_generator_forward_one_sample():
    """A generator forward pass on one timestep yields a (2, K, C, 1) output."""
    n_classes = 2**8
    net = model.Wavenet(model.HParams(n_classes=n_classes))
    gen = sample.Generator(net)

    targets = torch.randint(0, n_classes, (2, 2, 1))
    inputs = targets.float()
    logits, _loss = gen.forward(inputs, targets)

    want = (2, net.cfg.n_classes, net.cfg.n_audio_chans, 1)
    assert logits.shape == want
def test_distributed_data_parallel():
    """Smoke test: a DDP trainer runs `train()` for one epoch without raising."""
    model_cfg = model.HParams(n_audio_chans=2, n_layers=8).with_all_chans(2)
    train_cfg = train.HParams(max_epochs=1, batch_size=8)
    train_ds, _unused_ds = datasets.tracks(
        "fixtures/goldberg/short.wav", 0.8, model_cfg
    )
    net = model.Wavenet(model_cfg)

    # only the first dataset is passed on; the third argument is None
    trainer = distributed.DDP(net, train_ds, None, train_cfg)
    trainer.train()
def test_one_logit_generator_vs_wavenet(batch_norm):
    """Wavenet and generator must agree exactly on a one-timestep input."""
    settings = dict(
        mixed_precision=False,
        n_audio_chans=1,
        n_classes=16,
        n_chans=32,
        dilation_stacks=1,
        n_layers=6,
        batch_norm=batch_norm,
        compress=False,
    )
    net = model.Wavenet(model.HParams(**settings)).eval()
    gen = sample.Generator(net).eval()

    # ten examples rather than one: don't batch norm with batch size 1
    inputs = torch.rand((10, net.cfg.n_audio_chans, 1))

    # reference: forward pass through the original wavenet
    net_logits, _ = net.forward(inputs)
    net_probs = F.softmax(net_logits.squeeze(), dim=0)

    # candidate: forward pass through the generator; pad one step on the left
    # and crop one on the right of the time axis (causal)
    causal_inputs = F.pad(inputs, (1, -1))
    gen_logits, _ = gen.forward(causal_inputs)
    gen_probs = F.softmax(gen_logits.squeeze(), dim=0)

    assert torch.all(net_probs == gen_probs)
def test_many_logits_generator_vs_wavenet(batch_norm, n_samples):
    """Compare logits (no sampling) between the wavenet and its generator.

    The wavenet produces logits for the whole input in one forward pass. The
    generator is fed one timestep at a time, so its outputs are concatenated
    along the last axis before the two results are compared.
    """
    cfg = model.HParams(
        mixed_precision=False,
        n_audio_chans=1,
        n_classes=16,
        dilation_stacks=1,
        n_layers=1,
        batch_norm=batch_norm,
        compress=False,
    ).with_all_chans(32)
    utils.seed(cfg)  # reset seeds and use deterministic mode

    net = model.Wavenet(cfg).eval()

    def perturb_batchnorm(module):
        # overwrite BatchNorm1d parameters and running statistics
        if not isinstance(module, nn.BatchNorm1d):
            return
        module.weight.data.fill_(1.1)
        module.bias.data.fill_(0.1)
        module.running_var.fill_(1.1)
        module.running_mean.fill_(0.1)

    if cfg.batch_norm:
        # Move away from the vanilla init, otherwise a negative test would
        # not fail.
        net.apply(perturb_batchnorm)

    # reference: one forward pass of a random input through the wavenet
    inputs = torch.rand((1, net.cfg.n_audio_chans, n_samples))
    net_logits, _ = net.forward(inputs)

    # candidate: step the generator through the shifted input, one timestep
    # per call, accumulating the logits of every step
    gen = sample.Generator(net).eval()
    gen_logits = None
    shifted = F.pad(inputs, (1, -1))
    for step in range(n_samples):
        step_logits, _ = gen.forward(shifted[:, :, step:(step + 1)])
        gen_logits = (
            step_logits
            if gen_logits is None
            else torch.cat([gen_logits, step_logits], -1)
        )

    # posterior
    net_probs = F.softmax(net_logits.squeeze(), dim=0)
    gen_probs = F.softmax(gen_logits.squeeze(), dim=0)
    assert torch.allclose(net_probs, gen_probs)
def test_many_logits_fast_vs_simple(embed_inputs, n_audio_chans, batch_norm):
    """`sample.fast` and `sample.simple` must produce matching logits."""
    n_samples, n_examples = 100, 5
    cfg = model.HParams(
        mixed_precision=False,
        embed_inputs=embed_inputs,
        n_audio_chans=n_audio_chans,
        n_classes=20,
        dilation_stacks=1,
        n_layers=2,
        batch_norm=batch_norm,
        compress=False,
        sample_length=n_samples,
        seed=135,
    ).with_all_chans(16)
    utils.seed(cfg)

    dataset = datasets.Tiny(n_samples, n_examples)
    net = model.Wavenet(cfg).eval()

    def decoder(logits):
        # reseed before every decode call
        utils.seed(cfg)
        return utils.decode_random(logits)

    def perturb_batchnorm(module):
        # overwrite BatchNorm1d parameters and running statistics
        if not isinstance(module, nn.BatchNorm1d):
            return
        module.weight.data.fill_(1.1)
        module.bias.data.fill_(0.1)
        module.running_var.fill_(1.1)
        module.running_mean.fill_(0.1)

    if cfg.batch_norm:
        # Move away from the vanilla init, otherwise a negative test would
        # not fail.
        net.apply(perturb_batchnorm)

    sampler_kwargs = dict(n_samples=n_samples, batch_size=n_examples)

    # simple sampler
    utils.seed(cfg)
    _, simple_logits = sample.simple(
        net, dataset.transforms, decoder, **sampler_kwargs
    )
    simple_probs = torch.softmax(simple_logits.squeeze(), dim=0)

    # fast sampler (its third return value is not needed here)
    utils.seed(cfg)
    _, fast_logits, _gen = sample.fast(
        net, dataset.transforms, decoder, **sampler_kwargs
    )
    fast_probs = torch.softmax(fast_logits.squeeze(), dim=0)

    assert torch.allclose(fast_probs, simple_probs)
def test_loss_stable_across_batch_sizes():
    """Mean loss of freshly built models should be close for batch sizes 1 and 100."""
    n_trials = 100
    stats = {1: None, 100: None}

    for batch_size in stats:
        trial_losses = []
        for _ in range(n_trials):
            # new hyperparameters, data and model for every trial
            cfg = model.HParams()
            inputs, targets = datasets.to_tensor(
                datasets.StereoImpulse(batch_size, 8, cfg)
            )
            net = model.Wavenet(cfg)
            _, loss = net.forward(inputs, targets)
            trial_losses.append(loss.detach().numpy())
        # (mean, std) of the loss for this batch size
        stats[batch_size] = (np.mean(trial_losses), np.std(trial_losses))

    mean_losses = [mean_and_std[0] for mean_and_std in stats.values()]
    assert np.std(mean_losses) < 0.25, stats
def test_checkpoint():
    """A written checkpoint loads back with every expected key and restores the trainer."""
    required_keys = ("model", "optimizer", "scaler", "schedule", "epoch", "best")

    with helpers.tempdir() as tmp:
        model_cfg = model.HParams()
        net = model.Wavenet(model_cfg)
        train_cfg = train.HParams()
        trainer = train.Trainer(net, [1, 2, 3], [4, 5], train_cfg)

        path = tmp / "checkpoint"
        utils.checkpoint("test", trainer.state(), train_cfg, path)

        restored = torch.load(path)
        for key in required_keys:
            assert key in restored

        # must not raise
        trainer.load_state(restored)
def test_shifted_units_generator_vs_wavenet_one_sample():
    """The `shifted` units of wavenet and generator agree exactly on one timestep."""
    cfg = model.HParams(
        mixed_precision=False,
        n_audio_chans=1,
        n_classes=16,
        n_chans=32,
        dilation_stacks=1,
        n_layers=6,
        compress=False,
    )
    net = model.Wavenet(cfg)
    gen = sample.Generator(net)

    inputs = torch.rand((1, cfg.n_audio_chans, 1))
    causal_inputs = F.pad(inputs, (1, -1))  # causal

    net_out = net.shifted.forward(inputs).squeeze()
    gen_out = gen.shifted.forward(causal_inputs).squeeze()
    assert torch.all(net_out == gen_out)
def test_loss_jacobian_full_receptive_field(embed_inputs):
    """The last output's loss depends on exactly the receptive field of inputs.

    Uses the jacobian of the per-timestep loss with respect to the input and
    checks which input timesteps have a non-zero derivative for the final
    output timestep.
    """
    batch_size = 2
    cfg = model.HParams(
        embed_inputs=embed_inputs,
        n_audio_chans=1,
        n_classes=2,
        dilation_stacks=2,
        n_layers=4,
        sample_length=40,
    ).with_all_chans(10)
    net = model.Wavenet(cfg)

    # pin down the receptive field these hyperparameters should give
    assert cfg.receptive_field_size() == 32, cfg.receptive_field_size()
    assert cfg.sample_length > cfg.receptive_field_size()

    # every target is class index 1
    targets = torch.ones((batch_size, 1, cfg.sample_length), dtype=torch.long)

    def per_step_loss(x):
        logits, _ = net.forward(x)
        unreduced = F.cross_entropy(logits, targets, reduction="none")
        return unreduced.sum(1)  # N, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    inputs = torch.rand((batch_size, 1, cfg.sample_length))
    jac = jacobian(per_step_loss, inputs)

    # sum everything else to obtain WxW
    jac = jac.sum((0, 2, 3))

    # Last row of the WxW jacobian: the derivatives of the final output
    # timestep with respect to each input timestep. The final input timestep
    # is dropped since, due to temporal masking, it cannot affect the final
    # output timestep.
    last_row = jac[-1, :-1]

    # the count of non-zero derivatives equals the receptive field size
    assert last_row.ne(0.0).sum() == cfg.receptive_field_size()

    # and they sit exactly in the trailing receptive-field-sized window
    want = torch.zeros_like(last_row)
    want[-cfg.receptive_field_size():] = 1
    assert want.ne(0.0).equal(last_row.ne(0.0))
def test_wavenet_modules_registered():
    """A one-layer, one-stack wavenet exposes exactly these state dict keys, in order."""
    cfg = model.HParams(n_layers=1, dilation_stacks=1)
    net = model.Wavenet(cfg)

    want = [
        "shifted.weight",
        "shifted.bias",
        "layers.0.conv.weight",
        "layers.0.conv.bias",
        "layers.0.res1x1.weight",
        "layers.0.res1x1.bias",
        "layers.0.skip1x1.weight",
        "layers.0.skip1x1.bias",
        "a1x1.weight",
        "a1x1.bias",
        "b1x1.weight",
        "b1x1.bias",
    ]
    got = list(net.state_dict().keys())
    assert got == want
def test_logit_jacobian_first_sample():
    """With a `StereoImpulse(1, 1, p)` input, the logit jacobian must be all zero."""
    cfg = model.HParams()
    dataset = datasets.StereoImpulse(1, 1, cfg)
    net = model.Wavenet(cfg)

    def summed_logits(x):
        """Reduce the logits so that only N and the time dimension W remain."""
        out, _ = net.forward(x)
        return out.sum((1, 2))  # N, K, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    example, _ = dataset[0]
    jac = jacobian(summed_logits, example.unsqueeze(0))

    # sum everything else to obtain WxW
    jac = jac.sum((0, 2, 3))

    # gradients must remain unaffected by the input
    assert torch.unique(jac) == torch.zeros(1)
def test_logit_jacobian_many_samples():
    """The time-by-time logit jacobian must be lower triangular."""
    cfg = model.HParams()
    dataset = datasets.StereoImpulse(1, 8, cfg)  # 8 samples
    net = model.Wavenet(cfg)

    def summed_logits(x):
        """Reduce the logits so that only N and the time dimension W remain."""
        out, _ = net.forward(x)
        return out.sum((1, 2))  # N, K, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    example, _ = dataset[0]
    jac = jacobian(summed_logits, example.unsqueeze(0))

    # sum everything else to obtain WxW
    jac = jac.sum((0, 2, 3))

    # nothing above the diagonal may be non-zero
    lower = torch.tril(jac)
    assert torch.equal(lower, jac)
def test_loss_jacobian_many_samples():
    """The time-by-time jacobian of the per-step loss must be lower triangular."""
    cfg = model.HParams()
    dataset = datasets.StereoImpulse(1, 8, cfg)  # 8 samples
    net = model.Wavenet(cfg)

    def per_step_loss(x):
        out, _ = net.forward(x)
        # targets are derived from the input itself
        class_idxs = utils.audio_to_class_idxs(x, cfg.n_classes)
        unreduced = F.cross_entropy(out, class_idxs, reduction="none")
        return unreduced.sum(1)  # N, C, W -> N, W

    # input is N, C, W. output is N, W. jacobian is N, W, N, C, W
    example, _ = dataset[0]
    jac = jacobian(per_step_loss, example.unsqueeze(0))

    # sum everything else to obtain WxW
    jac = jac.sum((0, 2, 3))

    # nothing above the diagonal may be non-zero
    lower = torch.tril(jac)
    assert torch.equal(lower, jac)
def test_shifted_weights_generator_vs_wavenet():
    """The generator's `shifted.c` weight equals the wavenet's `shifted` weight."""
    net = model.Wavenet(model.HParams())
    gen = sample.Generator(net)

    net_weight = net.shifted.weight
    gen_weight = gen.shifted.c.weight
    assert torch.equal(net_weight, gen_weight)
def test_wavenet_dilation_stacks():
    """Two stacks of two layers give conv dilations 1, 2, 1, 2."""
    cfg = model.HParams(n_layers=2, dilation_stacks=2)
    net = model.Wavenet(cfg)

    # first entry of each layer's conv dilation
    got = [layer.conv.dilation[0] for layer in net.layers]
    assert got == [1, 2, 1, 2]
def test_wavenet_mono_input_embedded_output_shape():
    """Mono wavenet with embedded inputs returns a (3, 256, 1, 4) first output."""
    cfg = model.HParams(embed_inputs=True, n_audio_chans=1)
    net = model.Wavenet(cfg)

    targets = torch.randint(256, (3, 1, 4))
    inputs = targets.float()
    logits, _ = net.forward(inputs, targets)

    assert logits.shape == (3, 256, 1, 4)
def test_onecycle():
    """`utils.onecycle` builds a schedule whose `total_steps` is 9 here."""
    train_cfg = train.HParams(batch_size=1, max_epochs=1)
    net = model.Wavenet(model.HParams())
    sgd = torch.optim.SGD(net.parameters(), lr=train_cfg.learning_rate)

    schedule = utils.onecycle(sgd, 9, train_cfg)
    assert schedule.total_steps == 9
def test_generator_init():
    """A generator can be built from a default wavenet and is truthy."""
    net = model.Wavenet(model.HParams())
    gen = sample.Generator(net)
    assert gen
def test_two_samples():
    """Fast sampling of two samples returns a (1, C, 2) track."""
    net = model.Wavenet(model.HParams())
    transforms = datasets.AudioUnitTransforms(net.cfg)

    # only the first return value is checked
    track, *_ = sample.fast(net, transforms, utils.decode_random, n_samples=2)

    want = (1, net.cfg.n_audio_chans, 2)
    assert track.shape == want
def test_many_samples_with_embedding():
    """Fast sampling of 50 samples with embedded inputs returns a (1, C, 50) track."""
    cfg = model.HParams(embed_inputs=True)
    net = model.Wavenet(cfg)
    transforms = datasets.AudioUnitTransforms(net.cfg)

    # only the first return value is checked
    track, *_ = sample.fast(net, transforms, utils.decode_random, n_samples=50)

    want = (1, net.cfg.n_audio_chans, 50)
    assert track.shape == want
def test_wavenet_output_shape():
    """A default wavenet maps a (3, 2, 4) input to a (3, 256, 2, 4) first output."""
    net = model.Wavenet(model.HParams())

    targets = torch.randint(256, (3, 2, 4))
    inputs = targets.float()
    logits, _ = net.forward(inputs, targets)

    assert logits.shape == (3, 256, 2, 4)
def test_wavenet_mono_output_shape():
    """A mono wavenet maps a (3, 1, 4) input to a (3, 256, 1, 4) first output."""
    cfg = model.HParams(n_audio_chans=1)
    net = model.Wavenet(cfg)

    targets = torch.randint(256, (3, 1, 4))
    inputs = targets.float()
    logits, _ = net.forward(inputs, targets)

    assert logits.shape == (3, 256, 1, 4)
def test_lrfinder():
    """`utils.lrfinder` builds a schedule whose `gamma` is close to 10 here."""
    net = model.Wavenet(model.HParams())
    sgd = torch.optim.SGD(net.parameters(), lr=1e-8)
    train_cfg = train.HParams(batch_size=1, max_epochs=1)

    schedule = utils.lrfinder(sgd, 9, train_cfg)

    got = torch.tensor(schedule.gamma)
    want = torch.tensor(10.0)
    assert torch.isclose(got, want)