def test_dataset_loader():
    """Smoke-test the audio data pipeline.

    Builds a tiny FilesAudioDataset from the "teeny" hyperparameters,
    pulls one batch through a distributed DataLoader, runs audio
    pre/post-processing, and logs the before/after audio to tensorboard.
    Requires an initialized process group (`dist`) and a CUDA device.
    """
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
    from jukebox.hparams import setup_hparams
    from jukebox.data.files_dataset import FilesAudioDataset

    hps = setup_hparams("teeny", {})
    hps.sr = 22050  # 44100
    hps.hop_length = 512
    hps.labels = False
    hps.channels = 2
    hps.aug_shift = False
    hps.bs = 2
    hps.nworkers = 2  # Getting 20 it/s with 2 workers, 10 it/s with 1 worker
    print(hps)

    dataset = hps.dataset
    root = hps.root
    from tensorboardX import SummaryWriter
    sr = {22050: '22k', 44100: '44k', 48000: '48k'}[hps.sr]
    writer = SummaryWriter(f'{root}/{dataset}/logs/{sr}/logs')

    dataset = FilesAudioDataset(hps)
    print("Length of dataset", len(dataset))

    # Torch Loader
    collate_fn = lambda batch: t.stack([t.from_numpy(b) for b in batch], 0)
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(
        dataset,
        batch_size=hps.bs,
        num_workers=hps.nworkers,
        pin_memory=False,
        sampler=sampler,
        drop_last=True,
        collate_fn=collate_fn,
    )

    dist.barrier()
    sampler.set_epoch(0)
    for i, x in enumerate(tqdm(train_loader)):
        x = x.to('cuda', non_blocking=True)
        # Log the raw batch, then the round-tripped (pre+post processed) batch.
        for j, aud in enumerate(x):
            writer.add_audio('in_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote in")
        x = audio_preprocess(x, hps)
        x = audio_postprocess(x, hps)
        for j, aud in enumerate(x):
            writer.add_audio('out_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote out")
        dist.barrier()
        # One batch is enough for a smoke test.
        break
def forward(self, x, hps, loss_fn='l1'):
    """Autoencode a batch through every VQ-VAE level and compute losses.

    Encodes `x` with each level's encoder, quantises through the shared
    bottleneck, decodes each level independently, and accumulates
    reconstruction / spectral / multispectral losses per level plus the
    commitment loss.

    Args:
        x: raw audio batch tensor.
        hps: hyperparameter namespace (uses `use_nonrelative_specloss`
            and `bandwidth['spec']`, among others).
        loss_fn: name of the reconstruction loss ('l1', 'l2', 'linf').

    Returns:
        (x_out, loss, metrics): the bottom-level reconstruction
        (postprocessed), the scalar training loss, and a dict of
        detached per-level and aggregate metrics.
    """
    metrics = {}
    N = x.shape[0]  # batch size (kept for parity with the original code)

    # ---- Encode at every level, then quantise through the bottleneck ----
    x_in = self.preprocess(x)
    xs = []
    for level in range(self.levels):
        encoder = self.encoders[level]
        enc_out = encoder(x_in)
        xs.append(enc_out[-1])

    zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs)

    # ---- Decode each level independently, checking the output shape ----
    x_outs = []
    for level in range(self.levels):
        decoder = self.decoders[level]
        decoded = decoder(xs_quantised[level:level + 1], all_levels=False)
        assert_shape(decoded, x_in.shape)
        x_outs.append(decoded)

    # ---- Loss helpers ----
    def _spectral_loss(x_target, x_out, hps):
        # Absolute (bandwidth-normalised) vs relative spectral loss,
        # selected by hps.use_nonrelative_specloss.
        if hps.use_nonrelative_specloss:
            sl = spectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
        else:
            sl = spectral_convergence(x_target, x_out, hps)
        return t.mean(sl)

    def _multispectral_loss(x_target, x_out, hps):
        sl = multispectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
        return t.mean(sl)

    recons_loss = t.zeros(()).to(x.device)
    spec_loss = t.zeros(()).to(x.device)
    multispec_loss = t.zeros(()).to(x.device)
    x_target = audio_postprocess(x.float(), hps)

    # Walk levels top-down; after this loop x_out holds the level-0
    # (bottom) reconstruction, which is what gets returned below.
    for level in reversed(range(self.levels)):
        x_out = self.postprocess(x_outs[level])
        x_out = audio_postprocess(x_out, hps)

        this_recons_loss = _loss_fn(loss_fn, x_target, x_out, hps)
        this_spec_loss = _spectral_loss(x_target, x_out, hps)
        this_multispec_loss = _multispectral_loss(x_target, x_out, hps)

        metrics[f'recons_loss_l{level + 1}'] = this_recons_loss
        metrics[f'spectral_loss_l{level + 1}'] = this_spec_loss
        metrics[f'multispectral_loss_l{level + 1}'] = this_multispec_loss

        recons_loss += this_recons_loss
        spec_loss += this_spec_loss
        multispec_loss += this_multispec_loss

    commit_loss = sum(commit_losses)
    loss = (recons_loss
            + self.spectral * spec_loss
            + self.multispectral * multispec_loss
            + self.commit * commit_loss)

    # Extra diagnostics on the bottom-level output only; no gradients.
    with t.no_grad():
        sc = t.mean(spectral_convergence(x_target, x_out, hps))
        l2_loss = _loss_fn("l2", x_target, x_out, hps)
        l1_loss = _loss_fn("l1", x_target, x_out, hps)
        linf_loss = _loss_fn("linf", x_target, x_out, hps)

    quantiser_metrics = average_metrics(quantiser_metrics)

    metrics.update(dict(
        recons_loss=recons_loss,
        spectral_loss=spec_loss,
        multispectral_loss=multispec_loss,
        spectral_convergence=sc,
        l2_loss=l2_loss,
        l1_loss=l1_loss,
        linf_loss=linf_loss,
        commit_loss=commit_loss,
        **quantiser_metrics))

    for key, val in metrics.items():
        metrics[key] = val.detach()

    return x_out, loss, metrics
def prepare_aud(x, hps):
    """Postprocess a detached, contiguous copy of `x` and gather it
    across all distributed ranks via `allgather`."""
    detached = x.detach().contiguous()
    return allgather(audio_postprocess(detached, hps))